From 34c49d967d035f1a6e469e03c3e54efd2ac6d211 Mon Sep 17 00:00:00 2001 From: Nicholas Sielicki Date: Mon, 12 Aug 2024 12:51:42 -0700 Subject: [PATCH] fix(tree): move declarations to top of function (#563) C++ supports initializers anywhere in the function, but one must not jump over an initializer with any goto usage. Given the lack of RAII in C, this becomes a significant pain point. In large to-be-eventually-refactored functions that contain gotos or use switch statements, split declaration and initialization, and move all declarations to the top of the function. This makes switch statements and gotos safe in both languages. Signed-off-by: Nicholas Sielicki --- src/nccl_ofi_api.c | 25 ++++----- src/nccl_ofi_ep_addr_list.c | 4 +- src/nccl_ofi_net.c | 5 +- src/nccl_ofi_rdma.c | 102 +++++++++++++++++++++++------------- src/nccl_ofi_sendrecv.c | 34 ++++++------ src/platform-aws.c | 9 ++-- 6 files changed, 109 insertions(+), 70 deletions(-) diff --git a/src/nccl_ofi_api.c b/src/nccl_ofi_api.c index 5e5cc5275..957ab295c 100644 --- a/src/nccl_ofi_api.c +++ b/src/nccl_ofi_api.c @@ -365,15 +365,16 @@ ncclResult_t nccl_net_ofi_regMrDmaBuf(void* comm, void* data, size_t size, const nccl_ofi_mr_ckey_t cache_key = nccl_ofi_mr_ckey_mk_vec(data, size); #endif + nccl_net_ofi_send_comm_t *send_comm = NULL; + nccl_net_ofi_recv_comm_t *recv_comm = NULL; + switch (base_comm->type) { - case NCCL_NET_OFI_SEND_COMM:; - nccl_net_ofi_send_comm_t *send_comm = - (nccl_net_ofi_send_comm_t *)base_comm; + case NCCL_NET_OFI_SEND_COMM: + send_comm = (nccl_net_ofi_send_comm_t *)base_comm; ret = send_comm->regMr(send_comm, &cache_key, type, mhandle); break; - case NCCL_NET_OFI_RECV_COMM:; - nccl_net_ofi_recv_comm_t *recv_comm = - (nccl_net_ofi_recv_comm_t *)base_comm; + case NCCL_NET_OFI_RECV_COMM: + recv_comm = (nccl_net_ofi_recv_comm_t *)base_comm; ret = recv_comm->regMr(recv_comm, &cache_key, type, mhandle); break; default: @@ -397,16 +398,16 @@ ncclResult_t nccl_net_ofi_deregMr(void *comm, void 
*mhandle) } int ret = 0; + nccl_net_ofi_send_comm_t *send_comm = NULL; + nccl_net_ofi_recv_comm_t *recv_comm = NULL; switch (base_comm->type) { - case NCCL_NET_OFI_SEND_COMM:; - nccl_net_ofi_send_comm_t *send_comm = - (nccl_net_ofi_send_comm_t *)base_comm; + case NCCL_NET_OFI_SEND_COMM: + send_comm = (nccl_net_ofi_send_comm_t *)base_comm; ret = send_comm->deregMr(send_comm, (nccl_net_ofi_mr_handle_t *)mhandle); break; - case NCCL_NET_OFI_RECV_COMM:; - nccl_net_ofi_recv_comm_t *recv_comm = - (nccl_net_ofi_recv_comm_t *)base_comm; + case NCCL_NET_OFI_RECV_COMM: + recv_comm = (nccl_net_ofi_recv_comm_t *)base_comm; ret = recv_comm->deregMr(recv_comm, (nccl_net_ofi_mr_handle_t *)mhandle); break; default: diff --git a/src/nccl_ofi_ep_addr_list.c b/src/nccl_ofi_ep_addr_list.c index 0dda529d0..ba43bc355 100644 --- a/src/nccl_ofi_ep_addr_list.c +++ b/src/nccl_ofi_ep_addr_list.c @@ -134,6 +134,7 @@ int nccl_ofi_ep_addr_list_insert(nccl_ofi_ep_addr_list_t *ep_list, size_t addr_size) { int ret = 0; + ep_pair_list_elem_t *new_pair = NULL; if (addr_size > ep_list->max_addr_size) { NCCL_OFI_WARN("Address size %zu > max size (%zu)", addr_size, @@ -155,8 +156,7 @@ int nccl_ofi_ep_addr_list_insert(nccl_ofi_ep_addr_list_t *ep_list, memcpy(new_addr->addr, addr_in, addr_size); zero_pad_address(new_addr->addr, addr_size, ep_list->max_addr_size); - ep_pair_list_elem_t *new_pair = (ep_pair_list_elem_t *) - malloc(sizeof(*new_pair)); + new_pair = (ep_pair_list_elem_t *)malloc(sizeof(*new_pair)); if (!new_pair) { NCCL_OFI_WARN("Failed to allocate new ep list element"); free(new_addr); diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c index 891e2ee94..af13326bc 100644 --- a/src/nccl_ofi_net.c +++ b/src/nccl_ofi_net.c @@ -136,6 +136,8 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p) int ret = 0; const char *provider_filter = NULL; nccl_net_ofi_plugin_t *plugin; + nccl_net_ofi_ep_t *base_ep = NULL; + nccl_net_ofi_device_t *device = NULL; NCCL_OFI_INFO(NCCL_INIT | 
NCCL_NET, "Initializing " PACKAGE_STRING); @@ -275,8 +277,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p) * resources. This initialization happens once per process, and thus it * does not matter which device is used to create the endpoint. */ - nccl_net_ofi_device_t *device = plugin->get_device(plugin, 0); - nccl_net_ofi_ep_t *base_ep = NULL; + device = plugin->get_device(plugin, 0); ret = device->get_ep(device, &base_ep); if (ret != 0) { diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index b4dceca8e..712da9717 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -3164,7 +3164,13 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, int ret = 0; nccl_net_ofi_rdma_req_t *req = NULL; nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm; + rdma_req_recv_data_t *recv_data = NULL; + nccl_net_ofi_rdma_ep_t *ep = NULL; + nccl_net_ofi_rdma_device_t *device = NULL; + int dev_id = 0; nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles; + uint16_t msg_seq_num = 0; + bool eager = false; assert(r_comm != NULL); @@ -3181,12 +3187,12 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, goto error; } - int dev_id = r_comm->base.base.dev_id; + dev_id = r_comm->base.base.dev_id; - nccl_net_ofi_rdma_ep_t * ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; assert(ep != NULL); - nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep); + device = get_device_from_ep(ep); assert(device != NULL); ret = process_cq_if_pending(ep); @@ -3195,13 +3201,14 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, *base_req = NULL; ret = 0; goto error; - } else if (ret != 0) { + } + if (ret != 0) { goto error; } - uint16_t msg_seq_num = r_comm->next_msg_seq_num; + msg_seq_num = r_comm->next_msg_seq_num; - bool eager = false; + eager = false; void *elem; 
nccl_ofi_msgbuff_elemtype_t type; nccl_ofi_msgbuff_status_t msg_stat; @@ -3240,7 +3247,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, goto error; } - rdma_req_recv_data_t *recv_data = get_recv_data(req); + recv_data = get_recv_data(req); if (eager) { nccl_net_ofi_rdma_req_t *bounce_req = (nccl_net_ofi_rdma_req_t *)elem; @@ -3427,6 +3434,7 @@ static int alloc_and_reg_flush_buff(nccl_net_ofi_rdma_recv_comm_t *r_comm, int d static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm) { + nccl_net_ofi_rdma_device_t *device = NULL; int ret = 0; /* Retrieve and validate endpoint */ @@ -3437,7 +3445,7 @@ static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm) return ret; } - nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)base_ep->device; + device = (nccl_net_ofi_rdma_device_t *)base_ep->device; if (r_comm->send_close_req != NULL) { ret = r_comm->send_close_req->free(r_comm->send_close_req, false); @@ -3810,6 +3818,12 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, nccl_net_ofi_req_t **base_req) { int ret = 0; + int flush_n = 0; + bool network_busy = false; + int dev_id = 0; + nccl_net_ofi_rdma_ep_t *ep = NULL; + nccl_net_ofi_rdma_device_t *device = NULL; + nccl_net_ofi_scheduler_t *scheduler = NULL; nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm; @@ -3826,19 +3840,19 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, goto error; } - int dev_id = recv_comm->base.dev_id; + dev_id = recv_comm->base.dev_id; - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; assert(ep != NULL); - nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep); + device = get_device_from_ep(ep); assert(device != NULL); - nccl_net_ofi_scheduler_t *scheduler = device->scheduler; + scheduler = device->scheduler; assert(scheduler != NULL); /* Process any 
pending requests */ - bool network_busy = false; + network_busy = false; rc = process_cq_if_pending(ep); if (rc == -EAGAIN) { /* Network is still busy. */ @@ -3868,7 +3882,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, * Find the non-zero request for which we will issue flush. * A single operation can flush all request at once. */ - int flush_n = -1; + flush_n = -1; for (int recv_n = 0; recv_n < n; recv_n++) { if (sizes[recv_n] != 0) { flush_n = recv_n; @@ -4032,6 +4046,7 @@ static int rma_read(nccl_net_ofi_recv_comm_t *recv_comm, void* dest, size_t size nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm; nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle; nccl_net_ofi_rdma_req_t *req = NULL; + nccl_net_ofi_rdma_ep_t *ep = NULL; assert(r_comm != NULL); /* Support only NCCL_OFI_MAX_REQUESTS inflight requests. */ @@ -4042,7 +4057,7 @@ static int rma_read(nccl_net_ofi_recv_comm_t *recv_comm, void* dest, size_t size goto error; } - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; assert(ep != NULL); ret = process_cq_if_pending(ep); @@ -4114,7 +4129,9 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen { int ret = 0; + int comm_id = 0; nccl_net_ofi_rdma_recv_comm_t *r_comm = NULL; + nccl_net_ofi_rdma_ep_t *ep = NULL; int dev_id = device->base.dev_id; int num_rails = l_comm_ep->num_rails; @@ -4153,7 +4170,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen r_comm->n_ctrl_delivered = 0; /* Allocate recv communicator ID */ - int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); + comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { r_comm->local_comm_id = ~0; goto error; @@ -4212,7 +4229,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen 
r_comm->base.base.ep = &l_comm_ep->base; } - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; /* Add ourselves to ep's lookup array */ set_comm(device, r_comm->local_comm_id, &r_comm->base.base); @@ -4705,6 +4722,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, { int ret = 0; nccl_net_ofi_rdma_listen_comm_t *l_comm = NULL; + int comm_id = 0; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; @@ -4743,7 +4761,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, l_comm->leader_local_ep = ep->control_rail.ofi_ep; /* Allocate listen communicator ID */ - int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); + comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { l_comm->comm_id = ~0; ret = comm_id; @@ -5339,10 +5357,13 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t int ret = 0; nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm; nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle; + nccl_net_ofi_rdma_ep_t *ep = NULL; nccl_net_ofi_rdma_req_t *req = NULL; uint16_t msg_seq_num = s_comm->next_msg_seq_num; bool polled_cq = false; bool have_ctrl = false; + bool eager = false; + int dev_id = 0; assert(s_comm != NULL); @@ -5360,9 +5381,9 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t goto error; } - int dev_id = s_comm->base.base.dev_id; + dev_id = s_comm->base.base.dev_id; - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; assert(ep != NULL); ret = process_cq_if_pending(ep); @@ -5371,7 +5392,8 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t *base_req = NULL; ret = 0; goto error; - } else if (ret != 0) { + } + if (ret != 0) { goto error; } @@ -5379,6 +5401,10 @@ static int 
send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t * TODO: Use NCCL provided tags when using grouped receives aka * props->maxRecvs > 1. */ + + have_ctrl = false; + msg_seq_num = s_comm->next_msg_seq_num; + void *elem; nccl_ofi_msgbuff_elemtype_t type; nccl_ofi_msgbuff_status_t msg_stat; @@ -5433,7 +5459,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t } /* Determine if this should be sent eagerly. */ - bool eager = false; + eager = false; if ((!have_ctrl && size <= eager_max_size) || (size == 0)) { eager = true; @@ -5741,6 +5767,7 @@ static int rma_write_impl(nccl_net_ofi_send_comm_t *send_comm, void* src, size_t int ret = 0; nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm; nccl_net_ofi_rdma_req_t *req = NULL; + nccl_net_ofi_rdma_ep_t *ep = NULL; assert(s_comm != NULL); @@ -5752,7 +5779,7 @@ static int rma_write_impl(nccl_net_ofi_send_comm_t *send_comm, void* src, size_t goto error; } - nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; + ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; assert(ep != NULL); ret = process_cq_if_pending(ep); @@ -5860,6 +5887,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, nccl_net_ofi_rdma_send_comm_t **s_comm) { int ret = 0; + int comm_id = 0; fi_addr_t remote_addr; nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL; int num_rails = ep->num_rails; @@ -5916,7 +5944,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, ret_s_comm->remote_comm_id = handle->comm_id; /* Allocate send communicator ID */ - int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); + comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { ret_s_comm->local_comm_id = ~0; ret = comm_id; @@ -6111,6 +6139,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep, nccl_net_ofi_send_comm_t **send_comm) { int ret = 0; + nccl_net_ofi_rdma_req_state_t 
conn_resp_req_state; nccl_net_ofi_rdma_req_state_t conn_msg_state; *send_comm = NULL; nccl_net_ofi_rdma_ep_t *ep = @@ -6238,7 +6267,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep, } nccl_net_ofi_mutex_lock(&s_comm->conn_resp_req->req_lock); - nccl_net_ofi_rdma_req_state_t conn_resp_req_state = s_comm->conn_resp_req->state; + conn_resp_req_state = s_comm->conn_resp_req->state; nccl_net_ofi_mutex_unlock(&s_comm->conn_resp_req->req_lock); /* Wait until conn resp message is received */ @@ -6414,10 +6443,11 @@ static int init_rail_ofi_resources(nccl_net_ofi_rdma_device_t *device, static int release_ep(nccl_net_ofi_ep_t *base_ep) { int ret = 0; + nccl_net_ofi_rdma_ep_t *ep = NULL; + nccl_net_ofi_rdma_device_t *device = NULL; /* Validate device */ - nccl_net_ofi_rdma_ep_t *ep = - (nccl_net_ofi_rdma_ep_t*)base_ep; + ep = (nccl_net_ofi_rdma_ep_t *)base_ep; if (OFI_UNLIKELY(ep == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid endpoint provided"); @@ -6425,7 +6455,7 @@ static int release_ep(nccl_net_ofi_ep_t *base_ep) } /* Validate device */ - nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep); + device = get_device_from_ep(ep); if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -6605,10 +6635,10 @@ static inline int get_ep(nccl_net_ofi_device_t *base_dev, nccl_net_ofi_ep_t **ba int ret = 0; long thread_id; nccl_net_ofi_rdma_ep_t *ep = NULL; + nccl_net_ofi_rdma_device_t *device = NULL; /* Retrieve and validate device */ - nccl_net_ofi_rdma_device_t *device = - (nccl_net_ofi_rdma_device_t*)base_dev; + device = (nccl_net_ofi_rdma_device_t *)base_dev; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -6929,8 +6959,9 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, int dev_id, struct fi_info *info_list, nccl_ofi_topo_t *topo, size_t rr_threshold) { - int ret; - + int ret = 0; + bool provide_own_mr_key = false; + int length = 0; nccl_net_ofi_rdma_device_t 
*device = (nccl_net_ofi_rdma_device_t *)calloc(1, sizeof(nccl_net_ofi_rdma_device_t)); if (device == NULL) { @@ -6961,7 +6992,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, } /* Ensure that number of rails are the same across devices */ - int length = ofi_info_list_length(info_list); + length = ofi_info_list_length(info_list); if (topo->max_group_size != length) { NCCL_OFI_WARN("Wrong number of NICs for device %i. Expected %i but got %i", dev_id, topo->max_group_size, length); @@ -7038,7 +7069,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, } /* Initialize mr key pool */ - bool provide_own_mr_key = true; + provide_own_mr_key = true; ret = nccl_ofi_mr_keys_need_own_key(info_list, &provide_own_mr_key); if (ret != 0) { NCCL_OFI_WARN("MR key config parsing failed: %s", @@ -7259,6 +7290,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, nccl_net_ofi_rdma_plugin_t *plugin = NULL; nccl_ofi_topo_t *topo = NULL; struct fi_info *hints; + uint32_t api_version = 0; *found_multiple_rails = false; @@ -7270,7 +7302,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, } get_hints(hints); - uint32_t api_version = nccl_ofi_dmabuf_viable() ? FI_VERSION(1, 20) : FI_VERSION(1, 18); + api_version = nccl_ofi_dmabuf_viable() ? 
FI_VERSION(1, 20) : FI_VERSION(1, 18); ret = nccl_ofi_ofiutils_get_providers(provider_filter, api_version, hints, &provider_list, &num_providers); if (ret == 0) { diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c index 795dbafad..e97dec021 100644 --- a/src/nccl_ofi_sendrecv.c +++ b/src/nccl_ofi_sendrecv.c @@ -360,6 +360,8 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) { int ret = 0; nccl_net_ofi_sendrecv_req_t *req = (nccl_net_ofi_sendrecv_req_t *)base_req; + nccl_net_ofi_sendrecv_device_t *device = NULL; + nccl_net_ofi_sendrecv_ep_t *ep = NULL; /* Retrieve and validate comm */ nccl_net_ofi_comm_t *base_comm = req->comm; @@ -370,8 +372,7 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) } /* Retrieve and validate endpoint */ - nccl_net_ofi_sendrecv_ep_t *ep = - (nccl_net_ofi_sendrecv_ep_t *)base_comm->ep; + ep = (nccl_net_ofi_sendrecv_ep_t *)base_comm->ep; if (OFI_UNLIKELY(ep == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid endpoint provided"); @@ -379,8 +380,7 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size) } /* Retrieve and validate device */ - nccl_net_ofi_sendrecv_device_t *device = - (nccl_net_ofi_sendrecv_device_t*)ep->base.device; + device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -760,6 +760,7 @@ static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, /* Retrieve and validate endpoint */ nccl_net_ofi_sendrecv_ep_t *ep = (nccl_net_ofi_sendrecv_ep_t *)base_comm->ep; + nccl_ofi_idpool_t *key_pool = NULL; if (OFI_UNLIKELY(ep == NULL)) { NCCL_OFI_WARN("Invalid endpoint provided"); return -EINVAL; @@ -795,7 +796,7 @@ static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, } /* Cache miss */ - nccl_ofi_idpool_t *key_pool = &device->key_pool; + key_pool = &device->key_pool; struct fid_domain *domain; domain = get_domain_from_endpoint(ep); ret = reg_mr_base(domain, 
ep->ofi_ep, key_pool, @@ -889,14 +890,15 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, int ret = 0; ssize_t rc = 0; nccl_net_ofi_sendrecv_req_t *req = NULL; + nccl_net_ofi_sendrecv_ep_t *ep = NULL; + nccl_net_ofi_sendrecv_device_t *device = NULL; nccl_net_ofi_sendrecv_recv_comm_t *r_comm = (nccl_net_ofi_sendrecv_recv_comm_t *)recv_comm; int dev_id = r_comm->base.base.dev_id; struct fid_mr **mr_handles = (struct fid_mr **)mhandles; /* Retrieve and validate endpoint */ - nccl_net_ofi_sendrecv_ep_t * ep = - (nccl_net_ofi_sendrecv_ep_t *)r_comm->base.base.ep; + ep = (nccl_net_ofi_sendrecv_ep_t *)r_comm->base.base.ep; if (OFI_UNLIKELY(ep == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid endpoint provided"); @@ -904,8 +906,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, } /* Retrieve and validate device */ - nccl_net_ofi_sendrecv_device_t *device = - (nccl_net_ofi_sendrecv_device_t*)ep->base.device; + device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -1051,6 +1052,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, void *data = NULL; void *flush_mr_desc = NULL; int dev_id = recv_comm->base.dev_id; + int flush_n = -1; struct fid_mr **mr_handles = (struct fid_mr **)mhandles; if (ofi_nccl_gdr_flush_disable() || support_gdr == GDR_UNSUPPORTED) @@ -1073,7 +1075,6 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, * Find the non-zero request for which we will issue flush. * A single operation can flush all request at once. 
*/ - int flush_n = -1; for (int recv_n = 0; recv_n < n; recv_n++) { if (sizes[recv_n] != 0) { flush_n = recv_n; @@ -1552,6 +1553,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, fi_addr_t local_ep_addr; nccl_net_ofi_sendrecv_listen_comm_t *l_comm = NULL; uint64_t tag; + int dev_id = 0; int num_addrs; nccl_net_ofi_sendrecv_ep_t *ep = (nccl_net_ofi_sendrecv_ep_t *)base_ep; @@ -1565,7 +1567,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, goto exit; } - int dev_id = device->base.dev_id; + dev_id = device->base.dev_id; /* Zero-out the handle */ memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t)); @@ -1666,6 +1668,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t ssize_t rc = 0; nccl_net_ofi_sendrecv_req_t *req = NULL; void *desc = NULL; + nccl_net_ofi_sendrecv_device_t *device = NULL; int dev_id = s_comm->base.base.dev_id; struct fid_mr *mr_handle = (struct fid_mr *)mhandle; @@ -1679,8 +1682,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t } /* Retrieve and validate device */ - nccl_net_ofi_sendrecv_device_t *device = - (nccl_net_ofi_sendrecv_device_t*)ep->base.device; + device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -2114,6 +2116,7 @@ static int connect(nccl_net_ofi_ep_t *base_ep, static int release_ep(nccl_net_ofi_ep_t *base_ep) { int ret = 0; + nccl_net_ofi_sendrecv_device_t *device = NULL; /* Validate device */ nccl_net_ofi_sendrecv_ep_t *ep = @@ -2125,8 +2128,7 @@ static int release_ep(nccl_net_ofi_ep_t *base_ep) } /* Validate device */ - nccl_net_ofi_sendrecv_device_t *device = - (nccl_net_ofi_sendrecv_device_t*)ep->base.device; + device = (nccl_net_ofi_sendrecv_device_t *)ep->base.device; if (OFI_UNLIKELY(device == NULL)) { ret = -EINVAL; NCCL_OFI_WARN("Invalid device provided"); @@ -2385,6 +2387,7 @@ nccl_net_ofi_sendrecv_device_create(nccl_net_ofi_plugin_t *plugin, int 
dev_id, struct fi_info *info) { int ret; + bool provide_own_mr_key = true; nccl_net_ofi_sendrecv_device_t *device = (nccl_net_ofi_sendrecv_device_t *)calloc(1, sizeof(nccl_net_ofi_sendrecv_device_t)); @@ -2433,7 +2436,6 @@ nccl_net_ofi_sendrecv_device_create(nccl_net_ofi_plugin_t *plugin, } /* Indicates if the provider selects MR keys */ - bool provide_own_mr_key = true; ret = nccl_ofi_mr_keys_need_own_key(info, &provide_own_mr_key); if (ret != 0) { NCCL_OFI_WARN("MR key config parsing failed: %s", diff --git a/src/platform-aws.c b/src/platform-aws.c index 6e7ed5bfc..68ab64962 100644 --- a/src/platform-aws.c +++ b/src/platform-aws.c @@ -604,6 +604,10 @@ int platform_init(const char **provider_filter) int platform_config_endpoint(struct fi_info *info, struct fid_ep* endpoint) { int ret = 0; +#if HAVE_CUDA + const char *optname_name = "none"; + int optname = -1; +#endif if (endpoint == NULL) { NCCL_OFI_WARN("Unable to configure invalid endpoint"); @@ -645,8 +649,6 @@ int platform_config_endpoint(struct fi_info *info, struct fid_ep* endpoint) { static bool nccl_proto_configured = false; static bool need_ordering = false; static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; - int optname = -1; - const char *optname_name = "none"; /* During initialization, try to set * FI_OPT_EFA_{SENDRECV,WRTIE}_IN_ORDER_ALIGNED_128_BYTES to @@ -785,6 +787,7 @@ void platform_sort_rails(struct fi_info **info_list, int num_rails) { struct fi_info *info_list_in = *info_list; struct fi_info **sorted_info_array = (struct fi_info **)alloca(num_rails*sizeof(struct fi_info *)); + struct fi_info *info_ptr = NULL; if (num_rails <= 0) { return; @@ -824,7 +827,7 @@ void platform_sort_rails(struct fi_info **info_list, int num_rails) /* Update info_list references to match sorted order */ *info_list = sorted_info_array[0]; - struct fi_info *info_ptr = *info_list; + info_ptr = *info_list; for (int i = 0; i < num_rails; ++i) { assert(info_ptr); assert(sorted_info_array[i]);