Skip to content

Commit

Permalink
rdma: Use get_device_from_ep() accessor
Browse files Browse the repository at this point in the history
Clean up the code to always use the get_device_from_ep() accessor
function instead of reaching into ep->base.device (and casting it by
hand) at each call site.

Signed-off-by: Brian Barrett <[email protected]>
  • Loading branch information
bwbarrett authored and aws-nslick committed Sep 30, 2024
1 parent 3811ebe commit 062eab3
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ static inline int update_send_data_from_remote(nccl_net_ofi_rdma_send_comm_t *s_
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
assert(ep != NULL);

nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
nccl_net_ofi_scheduler_t *scheduler = device->scheduler;

rdma_req_send_data_t *send_data = get_send_data(req);
Expand Down Expand Up @@ -2078,7 +2078,7 @@ static inline nccl_net_ofi_rdma_req_t *alloc_bounce_req(nccl_net_ofi_rdma_ep_t *

req->comm = NULL;
req->type = NCCL_OFI_RDMA_BOUNCE;
req->dev_id = ep->base.device->dev_id;
req->dev_id = get_device_from_ep(ep)->base.dev_id;
req->free = free_bounce_req;

rdma_req_bounce_data_t *bounce_data = get_bounce_data(req);
Expand Down Expand Up @@ -2298,7 +2298,7 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm)
}

/* Retrieve and validate device */
device = (nccl_net_ofi_rdma_device_t*)ep->base.device;
device = get_device_from_ep(ep);
if (OFI_UNLIKELY(device == NULL)) {
NCCL_OFI_WARN("Invalid device provided");
return -EINVAL;
Expand Down Expand Up @@ -2598,8 +2598,7 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep,
uint64_t regattr_flags = 0;

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device =
(nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

int dev_id = device->base.dev_id;
Expand Down Expand Up @@ -2684,7 +2683,7 @@ static int reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep,
assert(ep);

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

nccl_ofi_idpool_t *key_pool = &device->key_pool;
Expand Down Expand Up @@ -2802,7 +2801,7 @@ static int reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm,
int type, void **mhandle)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)send_comm->base.ep;
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

return reg_mr_ep(ep,
Expand All @@ -2817,7 +2816,7 @@ static int reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm,
int type, void **mhandle)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)recv_comm->base.ep;
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

return reg_mr_ep(ep,
Expand Down Expand Up @@ -2862,7 +2861,7 @@ static int freelist_regmr_host_fn(void *ep_void_ptr, void *data, size_t size, vo
}

freelist_handle->mr_handle = mr_handle;
freelist_handle->key_pool = &((nccl_net_ofi_rdma_device_t *)ep->base.device)->key_pool;
freelist_handle->key_pool = &(get_device_from_ep(ep))->key_pool;
*handle = (void *)freelist_handle;
return 0;
}
Expand Down Expand Up @@ -2893,7 +2892,7 @@ static int dereg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm,
assert(ep != NULL);

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle;
Expand Down Expand Up @@ -3187,7 +3186,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
nccl_net_ofi_rdma_ep_t * ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
assert(ep != NULL);

nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

ret = process_cq_if_pending(ep);
Expand Down Expand Up @@ -3832,7 +3831,7 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
assert(ep != NULL);

nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
Expand Down Expand Up @@ -4486,7 +4485,7 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm,
}

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)l_comm_ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(l_comm_ep);
assert(device != NULL);

int dev_id = device->base.dev_id;
Expand Down Expand Up @@ -4710,7 +4709,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
(nccl_net_ofi_rdma_ep_t *)base_ep;

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

int dev_id = device->base.dev_id;
Expand Down Expand Up @@ -4784,7 +4783,7 @@ static int dereg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm,
assert(ep != NULL);

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
assert(device != NULL);

nccl_net_ofi_rdma_mr_handle_t *mr_handle =
Expand Down Expand Up @@ -4826,7 +4825,7 @@ static int alloc_rdma_send_req(nccl_net_ofi_rdma_send_comm_t *s_comm,
nccl_net_ofi_rdma_req_t **ret_req)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
*ret_req = NULL;

Expand Down Expand Up @@ -5869,7 +5868,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
*s_comm = NULL;

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
if (OFI_UNLIKELY(device == NULL)) {
NCCL_OFI_WARN("Error accessing device");
return -EINVAL;
Expand Down Expand Up @@ -6426,8 +6425,7 @@ static int release_ep(nccl_net_ofi_ep_t *base_ep)
}

/* Validate device */
nccl_net_ofi_rdma_device_t *device =
(nccl_net_ofi_rdma_device_t*)ep->base.device;
nccl_net_ofi_rdma_device_t *device = get_device_from_ep(ep);
if (OFI_UNLIKELY(device == NULL)) {
ret = -EINVAL;
NCCL_OFI_WARN("Invalid device provided");
Expand Down

0 comments on commit 062eab3

Please sign in to comment.