diff --git a/src/nccl_ofi_api.c b/src/nccl_ofi_api.c index 5e5cc5275..957ab295c 100644 --- a/src/nccl_ofi_api.c +++ b/src/nccl_ofi_api.c @@ -365,15 +365,16 @@ ncclResult_t nccl_net_ofi_regMrDmaBuf(void* comm, void* data, size_t size, const nccl_ofi_mr_ckey_t cache_key = nccl_ofi_mr_ckey_mk_vec(data, size); #endif + nccl_net_ofi_send_comm_t *send_comm = NULL; + nccl_net_ofi_recv_comm_t *recv_comm = NULL; + switch (base_comm->type) { - case NCCL_NET_OFI_SEND_COMM:; - nccl_net_ofi_send_comm_t *send_comm = - (nccl_net_ofi_send_comm_t *)base_comm; + case NCCL_NET_OFI_SEND_COMM: + send_comm = (nccl_net_ofi_send_comm_t *)base_comm; ret = send_comm->regMr(send_comm, &cache_key, type, mhandle); break; - case NCCL_NET_OFI_RECV_COMM:; - nccl_net_ofi_recv_comm_t *recv_comm = - (nccl_net_ofi_recv_comm_t *)base_comm; + case NCCL_NET_OFI_RECV_COMM: + recv_comm = (nccl_net_ofi_recv_comm_t *)base_comm; ret = recv_comm->regMr(recv_comm, &cache_key, type, mhandle); break; default: @@ -397,16 +398,16 @@ ncclResult_t nccl_net_ofi_deregMr(void *comm, void *mhandle) } int ret = 0; + nccl_net_ofi_send_comm_t *send_comm = NULL; + nccl_net_ofi_recv_comm_t *recv_comm = NULL; switch (base_comm->type) { - case NCCL_NET_OFI_SEND_COMM:; - nccl_net_ofi_send_comm_t *send_comm = - (nccl_net_ofi_send_comm_t *)base_comm; + case NCCL_NET_OFI_SEND_COMM: + send_comm = (nccl_net_ofi_send_comm_t *)base_comm; ret = send_comm->deregMr(send_comm, (nccl_net_ofi_mr_handle_t *)mhandle); break; - case NCCL_NET_OFI_RECV_COMM:; - nccl_net_ofi_recv_comm_t *recv_comm = - (nccl_net_ofi_recv_comm_t *)base_comm; + case NCCL_NET_OFI_RECV_COMM: + recv_comm = (nccl_net_ofi_recv_comm_t *)base_comm; ret = recv_comm->deregMr(recv_comm, (nccl_net_ofi_mr_handle_t *)mhandle); break; default: diff --git a/src/nccl_ofi_ep_addr_list.c b/src/nccl_ofi_ep_addr_list.c index 0dda529d0..ba43bc355 100644 --- a/src/nccl_ofi_ep_addr_list.c +++ b/src/nccl_ofi_ep_addr_list.c @@ -134,6 +134,7 @@ int nccl_ofi_ep_addr_list_insert(nccl_ofi_ep_addr_list_t *ep_list, size_t addr_size) { int ret = 0; + ep_pair_list_elem_t *new_pair = NULL; if (addr_size > ep_list->max_addr_size) { NCCL_OFI_WARN("Address size %zu > max size (%zu)", addr_size, @@ -155,8 +156,7 @@ int nccl_ofi_ep_addr_list_insert(nccl_ofi_ep_addr_list_t *ep_list, memcpy(new_addr->addr, addr_in, addr_size); zero_pad_address(new_addr->addr, addr_size, ep_list->max_addr_size); - ep_pair_list_elem_t *new_pair = (ep_pair_list_elem_t *) - malloc(sizeof(*new_pair)); + new_pair = (ep_pair_list_elem_t *)malloc(sizeof(*new_pair)); if (!new_pair) { NCCL_OFI_WARN("Failed to allocate new ep list element"); free(new_addr); diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c index 891e2ee94..af13326bc 100644 --- a/src/nccl_ofi_net.c +++ b/src/nccl_ofi_net.c @@ -136,6 +136,8 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p) int ret = 0; const char *provider_filter = NULL; nccl_net_ofi_plugin_t *plugin; + nccl_net_ofi_ep_t *base_ep = NULL; + nccl_net_ofi_device_t *device = NULL; NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Initializing " PACKAGE_STRING); @@ -275,8 +277,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p) * resources. This initialization happens once per process, and thus it * does not matter which device is used to create the endpoint. */ - nccl_net_ofi_device_t *device = plugin->get_device(plugin, 0); - nccl_net_ofi_ep_t *base_ep = NULL; + device = plugin->get_device(plugin, 0); ret = device->get_ep(device, &base_ep); if (ret != 0) { diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index b4dceca8e..712da9717 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -3164,7 +3164,13 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, int ret = 0; nccl_net_ofi_rdma_req_t *req = NULL; nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm; + rdma_req_recv_data_t *recv_data = NULL; + nccl_net_ofi_rdma_ep_t *ep = NULL; + nccl_net_ofi_rdma_device_t *device = NULL; + int dev_id = 0; nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles; + uint16_t msg_seq_num = 0; + bool eager = false; assert(r_comm != NULL); @@ -3181,12 +3187,12 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, goto error; } - int dev_id = r_comm->base.base.dev_id; + dev_id = r_comm->base.base.dev_id; - nccl_net_ofi_rdma_ep_t * ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; assert(ep != NULL); - nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep); + device = get_device_from_ep(ep); assert(device != NULL); ret = process_cq_if_pending(ep); @@ -3195,13 +3201,14 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, *base_req = NULL; ret = 0; goto error; - } else if (ret != 0) { + } + if (ret != 0) { goto error; } - uint16_t msg_seq_num = r_comm->next_msg_seq_num; + msg_seq_num = r_comm->next_msg_seq_num; - bool eager = false; + eager = false; void *elem; nccl_ofi_msgbuff_elemtype_t type; nccl_ofi_msgbuff_status_t msg_stat; @@ -3240,7 +3247,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, goto error; } - rdma_req_recv_data_t *recv_data = get_recv_data(req); + recv_data = get_recv_data(req); if (eager) { nccl_net_ofi_rdma_req_t *bounce_req = (nccl_net_ofi_rdma_req_t *)elem; @@ -3427,6 +3434,7 @@ static int alloc_and_reg_flush_buff(nccl_net_ofi_rdma_recv_comm_t *r_comm, int d static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm) { + nccl_net_ofi_rdma_device_t *device = NULL; int ret = 0; /* Retrieve and validate endpoint */ @@ -3437,7 +3445,7 @@ static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm) return ret; } - nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)base_ep->device; + device = (nccl_net_ofi_rdma_device_t *)base_ep->device; if (r_comm->send_close_req != NULL) { ret = r_comm->send_close_req->free(r_comm->send_close_req, false); @@ -3810,6 +3818,12 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, nccl_net_ofi_req_t **base_req) { int ret = 0; + int flush_n = 0; + bool network_busy = false; + int dev_id = 0; + nccl_net_ofi_rdma_ep_t *ep = NULL; + nccl_net_ofi_rdma_device_t *device = NULL; + nccl_net_ofi_scheduler_t *scheduler = NULL; nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm; @@ -3826,19 +3840,19 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, goto error; } - int dev_id = recv_comm->base.dev_id; + dev_id = recv_comm->base.dev_id; - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; assert(ep != NULL); - nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep); + device = get_device_from_ep(ep); assert(device != NULL); - nccl_net_ofi_scheduler_t *scheduler = device->scheduler; + scheduler = device->scheduler; assert(scheduler != NULL); /* Process any pending requests */ - bool network_busy = false; + network_busy = false; rc = process_cq_if_pending(ep); if (rc == -EAGAIN) { /* Network is still busy. */ @@ -3868,7 +3882,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, * Find the non-zero request for which we will issue flush. * A single operation can flush all request at once. */ - int flush_n = -1; + flush_n = -1; for (int recv_n = 0; recv_n < n; recv_n++) { if (sizes[recv_n] != 0) { flush_n = recv_n; @@ -4032,6 +4046,7 @@ static int rma_read(nccl_net_ofi_recv_comm_t *recv_comm, void* dest, size_t size nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm; nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle; nccl_net_ofi_rdma_req_t *req = NULL; + nccl_net_ofi_rdma_ep_t *ep = NULL; assert(r_comm != NULL); /* Support only NCCL_OFI_MAX_REQUESTS inflight requests. */ @@ -4042,7 +4057,7 @@ static int rma_read(nccl_net_ofi_recv_comm_t *recv_comm, void* dest, size_t size goto error; } - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; assert(ep != NULL); ret = process_cq_if_pending(ep); @@ -4114,7 +4129,9 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen { int ret = 0; + int comm_id = 0; nccl_net_ofi_rdma_recv_comm_t *r_comm = NULL; + nccl_net_ofi_rdma_ep_t *ep = NULL; int dev_id = device->base.dev_id; int num_rails = l_comm_ep->num_rails; @@ -4153,7 +4170,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen r_comm->n_ctrl_delivered = 0; /* Allocate recv communicator ID */ - int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); + comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { r_comm->local_comm_id = ~0; goto error; @@ -4212,7 +4229,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen r_comm->base.base.ep = &l_comm_ep->base; } - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; /* Add ourselves to ep's lookup array */ set_comm(device, r_comm->local_comm_id, &r_comm->base.base); @@ -4705,6 +4722,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, { int ret = 0; nccl_net_ofi_rdma_listen_comm_t *l_comm = NULL; + int comm_id = 0; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; @@ -4743,7 +4761,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, l_comm->leader_local_ep = ep->control_rail.ofi_ep; /* Allocate listen communicator ID */ - int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); + comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { l_comm->comm_id = ~0; ret = comm_id; @@ -5339,10 +5357,13 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t int ret = 0; nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm; nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle; + nccl_net_ofi_rdma_ep_t *ep = NULL; nccl_net_ofi_rdma_req_t *req = NULL; uint16_t msg_seq_num = s_comm->next_msg_seq_num; bool polled_cq = false; bool have_ctrl = false; + bool eager = false; + int dev_id = 0; assert(s_comm != NULL); @@ -5360,9 +5381,9 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t goto error; } - int dev_id = s_comm->base.base.dev_id; + dev_id = s_comm->base.base.dev_id; - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; assert(ep != NULL); ret = process_cq_if_pending(ep); @@ -5371,7 +5392,8 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t *base_req = NULL; ret = 0; goto error; - } else if (ret != 0) { + } + if (ret != 0) { goto error; } @@ -5379,6 +5401,10 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t * TODO: Use NCCL provided tags when using grouped receives aka * props->maxRecvs > 1. */ + + have_ctrl = false; + msg_seq_num = s_comm->next_msg_seq_num; + void *elem; nccl_ofi_msgbuff_elemtype_t type; nccl_ofi_msgbuff_status_t msg_stat; @@ -5433,7 +5459,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t } /* Determine if this should be sent eagerly. */ - bool eager = false; + eager = false; if ((!have_ctrl && size <= eager_max_size) || (size == 0)) { eager = true; @@ -5741,6 +5767,7 @@ static int rma_write_impl(nccl_net_ofi_send_comm_t *send_comm, void* src, size_t int ret = 0; nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm; nccl_net_ofi_rdma_req_t *req = NULL; + nccl_net_ofi_rdma_ep_t *ep = NULL; assert(s_comm != NULL); @@ -5752,7 +5779,7 @@ static int rma_write_impl(nccl_net_ofi_send_comm_t *send_comm, void* src, size_t goto error; } - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; assert(ep != NULL); ret = process_cq_if_pending(ep); @@ -5860,6 +5887,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, nccl_net_ofi_rdma_send_comm_t **s_comm) { int ret = 0; + int comm_id = 0; fi_addr_t remote_addr; nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL; int num_rails = ep->num_rails; @@ -5916,7 +5944,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, ret_s_comm->remote_comm_id = handle->comm_id; /* Allocate send communicator ID */ - int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); + comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { ret_s_comm->local_comm_id = ~0; ret = comm_id; @@ -6111,6 +6139,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep, nccl_net_ofi_send_comm_t **send_comm) { int ret = 0; + nccl_net_ofi_rdma_req_state_t conn_resp_req_state; nccl_net_ofi_rdma_req_state_t conn_msg_state; *send_comm = NULL; nccl_net_ofi_rdma_ep_t *ep = @@ -6238,7 +6267,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep, } nccl_net_ofi_mutex_lock(&s_comm->conn_resp_req->req_lock); - nccl_net_ofi_rdma_req_state_t conn_resp_req_state = s_comm->conn_resp_req->state; + conn_resp_req_state = s_comm->conn_resp_req->state; nccl_net_ofi_mutex_unlock(&s_comm->conn_resp_req->req_lock); /* Wait until conn resp message is received */ @@ -6414,10 +6443,11 @@ static int init_rail_ofi_resources(nccl_net_ofi_rdma_device_t *device, static int release_ep(nccl_net_ofi_ep_t *base_ep) { int ret = 0; + nccl_net_ofi_rdma_ep_t *ep = NULL; + nccl_net_ofi_rdma_device_t *device = NULL; /* Validate device */ - nccl_net_ofi_rdma_ep_t *ep = - (nccl_net_ofi_rdma_ep_t*)base_ep; + ep = (nccl_net_ofi_rdma_ep_t *)base_ep; if (OFI_UNLIKELY(ep == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid endpoint provided"); @@ -6425,7 +6455,7 @@ static int release_ep(nccl_net_ofi_ep_t *base_ep) } /* Validate device */ - nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep); + device = get_device_from_ep(ep); if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -6605,10 +6635,10 @@ static inline int get_ep(nccl_net_ofi_device_t *base_dev, nccl_net_ofi_ep_t **ba int ret = 0; long thread_id; nccl_net_ofi_rdma_ep_t *ep = NULL; + nccl_net_ofi_rdma_device_t *device = NULL; /* Retrieve and validate device */ - nccl_net_ofi_rdma_device_t *device = - (nccl_net_ofi_rdma_device_t*)base_dev; + device = (nccl_net_ofi_rdma_device_t *)base_dev; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -6929,8 +6959,9 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, int dev_id, struct fi_info *info_list, nccl_ofi_topo_t *topo, size_t rr_threshold) { - int ret; - + int ret = 0; + bool provide_own_mr_key = false; + int length = 0; nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)calloc(1, sizeof(nccl_net_ofi_rdma_device_t)); if (device == NULL) { @@ -6961,7 +6992,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, } /* Ensure that number of rails are the same across devices */ - int length = ofi_info_list_length(info_list); + length = ofi_info_list_length(info_list); if (topo->max_group_size != length) { NCCL_OFI_WARN("Wrong number of NICs for device %i. Expected %i but got %i", dev_id, topo->max_group_size, length); @@ -7038,7 +7069,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, } /* Initialize mr key pool */ - bool provide_own_mr_key = true; + provide_own_mr_key = true; ret = nccl_ofi_mr_keys_need_own_key(info_list, &provide_own_mr_key); if (ret != 0) { NCCL_OFI_WARN("MR key config parsing failed: %s", @@ -7259,6 +7290,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, nccl_net_ofi_rdma_plugin_t *plugin = NULL; nccl_ofi_topo_t *topo = NULL; struct fi_info *hints; + uint32_t api_version = 0; *found_multiple_rails = false; @@ -7270,7 +7302,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, } get_hints(hints); - uint32_t api_version = nccl_ofi_dmabuf_viable() ? FI_VERSION(1, 20) : FI_VERSION(1, 18); + api_version = nccl_ofi_dmabuf_viable() ? FI_VERSION(1, 20) : FI_VERSION(1, 18); ret = nccl_ofi_ofiutils_get_providers(provider_filter, api_version, hints, &provider_list, &num_providers); if (ret == 0) { diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c index 795dbafad..e97dec021 100644 --- a/src/nccl_ofi_sendrecv.c +++ b/src/nccl_ofi_sendrecv.c @@ -360,6 +360,8 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) { int ret = 0; nccl_net_ofi_sendrecv_req_t *req = (nccl_net_ofi_sendrecv_req_t *)base_req; + nccl_net_ofi_sendrecv_device_t *device = NULL; + nccl_net_ofi_sendrecv_ep_t *ep = NULL; /* Retrieve and validate comm */ nccl_net_ofi_comm_t *base_comm = req->comm; @@ -370,8 +372,7 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) } /* Retrieve and validate endpoint */ - nccl_net_ofi_sendrecv_ep_t *ep = - (nccl_net_ofi_sendrecv_ep_t *)base_comm->ep; + ep = (nccl_net_ofi_sendrecv_ep_t *)base_comm->ep; if (OFI_UNLIKELY(ep == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid endpoint provided"); @@ -379,8 +380,7 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) } /* Retrieve and validate device */ - nccl_net_ofi_sendrecv_device_t *device = - (nccl_net_ofi_sendrecv_device_t*)ep->base.device; + device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -760,6 +760,7 @@ static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, /* Retrieve and validate endpoint */ nccl_net_ofi_sendrecv_ep_t *ep = (nccl_net_ofi_sendrecv_ep_t *)base_comm->ep; + nccl_ofi_idpool_t *key_pool = NULL; if (OFI_UNLIKELY(ep == NULL)) { NCCL_OFI_WARN("Invalid endpoint provided"); return -EINVAL; @@ -795,7 +796,7 @@ static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, } /* Cache miss */ - nccl_ofi_idpool_t *key_pool = &device->key_pool; + key_pool = &device->key_pool; struct fid_domain *domain; domain = get_domain_from_endpoint(ep); ret = reg_mr_base(domain, ep->ofi_ep, key_pool, @@ -889,14 +890,15 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, int ret = 0; ssize_t rc = 0; nccl_net_ofi_sendrecv_req_t *req = NULL; + nccl_net_ofi_sendrecv_ep_t *ep = NULL; + nccl_net_ofi_sendrecv_device_t *device = NULL; nccl_net_ofi_sendrecv_recv_comm_t *r_comm = (nccl_net_ofi_sendrecv_recv_comm_t *)recv_comm; int dev_id = r_comm->base.base.dev_id; struct fid_mr **mr_handles = (struct fid_mr **)mhandles; /* Retrieve and validate endpoint */ - nccl_net_ofi_sendrecv_ep_t * ep = - (nccl_net_ofi_sendrecv_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_sendrecv_ep_t *)r_comm->base.base.ep; if (OFI_UNLIKELY(ep == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid endpoint provided"); @@ -904,8 +906,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, } /* Retrieve and validate device */ - nccl_net_ofi_sendrecv_device_t *device = - (nccl_net_ofi_sendrecv_device_t*)ep->base.device; + device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -1051,6 +1052,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, void *data = NULL; void *flush_mr_desc = NULL; int dev_id = recv_comm->base.dev_id; + int flush_n = -1; struct fid_mr **mr_handles = (struct fid_mr **)mhandles; if (ofi_nccl_gdr_flush_disable() || support_gdr == GDR_UNSUPPORTED) @@ -1073,7 +1075,6 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, * Find the non-zero request for which we will issue flush. * A single operation can flush all request at once. */ - int flush_n = -1; for (int recv_n = 0; recv_n < n; recv_n++) { if (sizes[recv_n] != 0) { flush_n = recv_n; @@ -1552,6 +1553,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, fi_addr_t local_ep_addr; nccl_net_ofi_sendrecv_listen_comm_t *l_comm = NULL; uint64_t tag; + int dev_id = 0; int num_addrs; nccl_net_ofi_sendrecv_ep_t *ep = (nccl_net_ofi_sendrecv_ep_t *)base_ep; @@ -1565,7 +1567,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, goto exit; } - int dev_id = device->base.dev_id; + dev_id = device->base.dev_id; /* Zero-out the handle */ memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t)); @@ -1666,6 +1668,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t ssize_t rc = 0; nccl_net_ofi_sendrecv_req_t *req = NULL; void *desc = NULL; + nccl_net_ofi_sendrecv_device_t *device = NULL; int dev_id = s_comm->base.base.dev_id; struct fid_mr *mr_handle = (struct fid_mr *)mhandle; @@ -1679,8 +1682,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t } /* Retrieve and validate device */ - nccl_net_ofi_sendrecv_device_t *device = - (nccl_net_ofi_sendrecv_device_t*)ep->base.device; + device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -2114,6 +2116,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep, static int release_ep(nccl_net_ofi_ep_t *base_ep) { int ret = 0; + nccl_net_ofi_sendrecv_device_t *device = NULL; /* Validate device */ nccl_net_ofi_sendrecv_ep_t *ep = @@ -2125,8 +2128,7 @@ static int release_ep(nccl_net_ofi_ep_t *base_ep) } /* Validate device */ - nccl_net_ofi_sendrecv_device_t *device = - (nccl_net_ofi_sendrecv_device_t*)ep->base.device; + device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -2385,6 +2387,7 @@ nccl_net_ofi_sendrecv_device_create(nccl_net_ofi_plugin_t *plugin, int dev_id, struct fi_info *info) { int ret; + bool provide_own_mr_key = true; nccl_net_ofi_sendrecv_device_t *device = (nccl_net_ofi_sendrecv_device_t *)calloc(1, sizeof(nccl_net_ofi_sendrecv_device_t)); @@ -2433,7 +2436,6 @@ nccl_net_ofi_sendrecv_device_create(nccl_net_ofi_plugin_t *plugin, } /* Indicates if the provider selects MR keys */ - bool provide_own_mr_key = true; ret = nccl_ofi_mr_keys_need_own_key(info, &provide_own_mr_key); if (ret != 0) { NCCL_OFI_WARN("MR key config parsing failed: %s", diff --git a/src/platform-aws.c b/src/platform-aws.c index 6e7ed5bfc..68ab64962 100644 --- a/src/platform-aws.c +++ b/src/platform-aws.c @@ -604,6 +604,10 @@ int platform_init(const char **provider_filter) int platform_config_endpoint(struct fi_info *info, struct fid_ep* endpoint) { int ret = 0; +#if HAVE_CUDA + const char *optname_name = "none"; + int optname = -1; +#endif if (endpoint == NULL) { NCCL_OFI_WARN("Unable to configure invalid endpoint"); @@ -645,8 +649,6 @@ int platform_config_endpoint(struct fi_info *info, struct fid_ep* endpoint) { static bool nccl_proto_configured = false; static bool need_ordering = false; static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; - int optname = -1; - const char *optname_name = "none"; /* During initialization, try to set * FI_OPT_EFA_{SENDRECV,WRTIE}_IN_ORDER_ALIGNED_128_BYTES to @@ -785,6 +787,7 @@ void platform_sort_rails(struct fi_info **info_list, int num_rails) { struct fi_info *info_list_in = *info_list; struct fi_info **sorted_info_array = (struct fi_info **)alloca(num_rails*sizeof(struct fi_info *)); + struct fi_info *info_ptr = NULL; if (num_rails <= 0) { return; @@ -824,7 +827,7 @@ void platform_sort_rails(struct fi_info **info_list, int num_rails) /* Update info_list references to match sorted order */ *info_list = sorted_info_array[0]; - struct fi_info *info_ptr = *info_list; + info_ptr = *info_list; for (int i = 0; i < num_rails; ++i) { assert(info_ptr); assert(sorted_info_array[i]);