Skip to content

Commit

Permalink
rdma: Move control messages to own endpoint
Browse files Browse the repository at this point in the history
For the long message RDMA protocol, we want to make sure that we
never starve the sender for data to move, which means prioritizing
control messages from the receiver to the sender.  This patch moves
both the communicator setup and recv control messages to a new
endpoint, which is always on device rail 0.  Future patches will
optimize polling of the control message CQ in the send path
and set priority bits on the control CQ.

Signed-off-by: Brian Barrett <[email protected]>
Signed-off-by: Raghu Raja <[email protected]>
  • Loading branch information
rajachan committed Aug 26, 2024
1 parent 658fe1d commit b7a686a
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 26 deletions.
11 changes: 10 additions & 1 deletion include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ typedef uint16_t nccl_ofi_rdma_msg_type_t;
* allocate a rdma memory registration handle with `num_rails' rails.
*/
typedef struct nccl_net_ofi_rdma_mr_handle {
struct fid_mr *control_mr;

int num_rails;

/* Array of size `num_rails' */
Expand Down Expand Up @@ -394,13 +396,15 @@ typedef struct nccl_ofi_rdma_connection_info {
* on the receiver side */
uint32_t remote_comm_id;

nccl_ofi_rdma_ep_name_t control_ep_name;

/* Array of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t`
* structs. The member `num_rails` indicates the number of
* entries that are in use. */
nccl_ofi_rdma_ep_name_t ep_names[MAX_NUM_RAILS];
} nccl_ofi_rdma_connection_info_t;
/* Since this is a message on the wire, check that it has the expected size */
_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 272,
_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 336,
"Wrong size for RDMA connect message");

/*
Expand Down Expand Up @@ -452,6 +456,8 @@ typedef struct nccl_net_ofi_rdma_send_comm {

nccl_ofi_msgbuff_t *msgbuff;

nccl_net_ofi_rdma_send_comm_rail_t control_rail;

/* Number of rails */
int num_rails;

Expand Down Expand Up @@ -534,6 +540,7 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
#if HAVE_NVTX_TRACING
nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM];
#endif
nccl_net_ofi_rdma_recv_comm_rail_t control_rail;

/* Number of rails */
int num_rails;
Expand Down Expand Up @@ -626,6 +633,8 @@ struct nccl_net_ofi_rdma_ep {
* and its base struct. */
nccl_net_ofi_ep_t base;

nccl_net_ofi_ep_rail_t control_rail;

/* Number of rails */
int num_rails;

Expand Down
123 changes: 98 additions & 25 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,11 @@ static int ofi_process_cq(nccl_net_ofi_rdma_ep_t *ep)
}
}

ret = ofi_process_cq_rail(ep, &ep->control_rail);
if (ret != 0) {
goto exit;
}

/* Process any pending requests */
ret = process_pending_reqs(ep);
if (OFI_UNLIKELY(ret != 0 && ret != -FI_EAGAIN)) {
Expand Down Expand Up @@ -2012,6 +2017,12 @@ static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep)
}
}

ret = post_bounce_buffs_on_rail(ep, &ep->control_rail);
if (ret != 0) {
NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)");
goto exit;
}

exit:
return ret;
}
Expand Down Expand Up @@ -2304,14 +2315,23 @@ static int prepare_recv_conn_req(nccl_net_ofi_rdma_listen_comm_t *l_comm)
*/
static int dereg_rails(nccl_net_ofi_rdma_mr_handle_t *handle)
{
/* Cleanup memory registration */
int ret = 0;
int rc = 0;
int num_rails = handle->num_rails;

/* Cleanup memory registration for control */
rc = fi_close(&handle->control_mr->fid);
if (OFI_UNLIKELY(rc != 0)) {
NCCL_OFI_WARN("Unable to de-register memory on control mr. RC: %d, Error: %s",
rc, fi_strerror(-rc));
ret = rc;
}

/* Cleanup memory registration for data rails */
for (int rail_id = 0; rail_id != num_rails; ++rail_id) {
/* No memory registration available for this rail */
if (!handle->mr[rail_id]) continue;
int rc = fi_close(&handle->mr[rail_id]->fid);
rc = fi_close(&handle->mr[rail_id]->fid);
if (OFI_UNLIKELY(rc != 0)) {
NCCL_OFI_WARN("Unable to de-register memory. RC: %d, Error: %s",
rc, fi_strerror(-rc));
Expand Down Expand Up @@ -2361,7 +2381,7 @@ static int dereg_mr_ep(nccl_net_ofi_rdma_mr_handle_t *mr_handle,
return -EINVAL;
}

if (OFI_UNLIKELY(mr_handle->num_rails < 1)) {
if (OFI_UNLIKELY(mr_handle->num_rails < 0)) {
NCCL_OFI_WARN("Unexpected number of rails in rdma memory registration handle");
return -EINVAL;
}
Expand Down Expand Up @@ -2444,6 +2464,15 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep,
goto exit;
}

ret = register_rail_mr_buffer(get_device_rail(device, 0)->domain, ep->control_rail.ofi_ep,
-1, type, &mr_attr,
&ret_handle->control_mr);
if (OFI_UNLIKELY(ret != 0)) {
free(ret_handle);
ret_handle = NULL;
goto exit;
}

/* Register memory on each rail */
ret_handle->num_rails = num_rails;
for (int rail_id = 0; rail_id != num_rails; ++rail_id) {
Expand Down Expand Up @@ -2743,7 +2772,8 @@ static inline nccl_net_ofi_rdma_req_t *allocate_req(nccl_ofi_freelist_t *fl)
}

/**
* @brief Allocate a new send ctrl req from freelist
* @brief Allocate a new control message that the receiver will
* send to the sender describing the recv buffer.
*/
static inline int insert_send_ctrl_req(
nccl_net_ofi_rdma_recv_comm_t *r_comm,
Expand Down Expand Up @@ -3481,7 +3511,8 @@ static inline nccl_net_ofi_rdma_recv_comm_t *calloc_rdma_recv_comm(int num_rails
* @return Receive communicator object, on success
* NULL, on error
*/
static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device_t *device,
static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen_comm_t *l_comm,
nccl_net_ofi_rdma_device_t *device,
nccl_net_ofi_rdma_ep_t *l_comm_ep,
nccl_ofi_rdma_connection_info_t *conn_msg)
{
Expand Down Expand Up @@ -3576,6 +3607,25 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device
/* Add ourselves to ep's lookup array */
set_comm(device, r_comm->local_comm_id, &r_comm->base.base);

r_comm->control_rail.local_ep = l_comm->leader_local_ep;
ret = fi_av_insert(ep->control_rail.av, (void *)conn_msg->control_ep_name.ep_name, 1,
&r_comm->control_rail.remote_addr, 0, NULL);
if (OFI_UNLIKELY(ret != 1)) {
NCCL_OFI_WARN("Unable to insert remote address into address vector "
"for device %d. RC: %d",
dev_id, fi_strerror(-ret));
goto error;
}

ret = fi_av_insert(ep->control_rail.av, (void *)ep->control_rail.local_ep_name, 1,
&r_comm->control_rail.local_addr, 0, NULL);
if (OFI_UNLIKELY(ret != 1)) {
NCCL_OFI_WARN("Unable to insert local address into address vector "
"for device %d. RC: %d",
dev_id, fi_strerror(-ret));
goto error;
}

/* Allocate array of communicator rails */
r_comm->num_rails = num_rails;

Expand Down Expand Up @@ -3748,7 +3798,7 @@ static int post_send_conn_resp(nccl_net_ofi_rdma_recv_comm_t *r_comm,
nccl_net_ofi_rdma_req_t *req)
{
ssize_t rc = 0;
nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = get_recv_comm_rail(r_comm, 0);
nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail;;

req->state = NCCL_OFI_RDMA_REQ_PENDING;
rc = fi_send(comm_rail->local_ep, (void *)conn_resp, sizeof(nccl_ofi_rdma_connection_info_t), NULL,
Expand Down Expand Up @@ -3888,7 +3938,7 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm,
}

/* Prepare receive communicator object for the received peer connection */
r_comm = prepare_recv_comm(device, l_comm_ep, conn_msg);
r_comm = prepare_recv_comm(l_comm, device, l_comm_ep, conn_msg);
if (OFI_UNLIKELY(r_comm == NULL)) {
ret = -EINVAL;
goto exit;
Expand Down Expand Up @@ -4042,7 +4092,6 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
nccl_net_ofi_rdma_listen_comm_t *l_comm = NULL;
nccl_net_ofi_rdma_ep_t *ep =
(nccl_net_ofi_rdma_ep_t *)base_ep;
nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0);

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)ep->base.device;
Expand All @@ -4052,14 +4101,14 @@ static int listen(nccl_net_ofi_ep_t *base_ep,

/* Build handle */
memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t));
assert(sizeof(handle->ep_name) == sizeof(first_rail->local_ep_name));
memcpy(handle->ep_name, first_rail->local_ep_name,
first_rail->local_ep_name_len);
assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name));
memcpy(handle->ep_name, ep->control_rail.local_ep_name,
ep->control_rail.local_ep_name_len);
/* We don't copy the size here since the handle doesn't have a size field.
The size will be distributed later by the connect response message.
Instead, zero the unused bytes here. */
memset(handle->ep_name + first_rail->local_ep_name_len, 0,
sizeof(handle->ep_name) - first_rail->local_ep_name_len);
memset(handle->ep_name + ep->control_rail.local_ep_name_len, 0,
sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len);

/* Build listen_comm */
l_comm = (nccl_net_ofi_rdma_listen_comm_t *)calloc(1,
Expand All @@ -4076,7 +4125,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
l_comm->base.base.dev_id = dev_id;
l_comm->base.accept = accept;
l_comm->base.close = listen_close;
l_comm->leader_local_ep = first_rail->ofi_ep;
l_comm->leader_local_ep = ep->control_rail.ofi_ep;

/* Allocate listen communicator ID */
int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool);
Expand Down Expand Up @@ -4381,11 +4430,9 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req)
nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)req->comm;
rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req);
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
const int control_rail_id = 0;

// Get communicator rail information to xfer the req
nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail;
comm_rail = get_recv_comm_rail(r_comm, control_rail_id);
nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail;

nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = send_ctrl_data->ctrl_fl_item;

Expand All @@ -4394,7 +4441,7 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req)
(freelist_regmr_fn_handle_t *)ctrl_fl_item->fl_reginfo.mr_handle;
nccl_net_ofi_rdma_mr_handle_t *mr_handle = fl_handle->mr_handle;

void *desc = fi_mr_desc(mr_handle->mr[control_rail_id]);
void *desc = fi_mr_desc(mr_handle->control_mr);

NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, xfer_info->rail_id, req->comm, req, req->msg_seq_num);

Expand Down Expand Up @@ -4808,6 +4855,14 @@ static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id,
/* Set number of rails to be sent back to remote for verification */
conn_msg->num_rails = num_rails;

/* Set libfabric endpoint name for control rail */
memcpy(conn_msg->control_ep_name.ep_name,
ep->control_rail.local_ep_name,
ep->control_rail.local_ep_name_len);
conn_msg->control_ep_name.ep_name_len =
ep->control_rail.local_ep_name_len;


/* Set libfabric endpoint names for each rail */
for (int rail_id = 0; rail_id != num_rails; ++rail_id) {
memcpy(conn_msg->ep_names[rail_id].ep_name,
Expand Down Expand Up @@ -4866,6 +4921,14 @@ static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep)
return ret;
}

ep->control_rail.min_bounce_posted = NCCL_OFI_DIV_CEIL(
ofi_nccl_rdma_min_posted_bounce_buffers(), ep->num_rails
);
ep->control_rail.max_bounce_posted = NCCL_OFI_DIV_CEIL(
ofi_nccl_rdma_max_posted_bounce_buffers(), ep->num_rails
);
ret = pthread_mutex_init(&ep->control_rail.bounce_mutex, NULL);

for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) {
nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
rail->min_bounce_posted = NCCL_OFI_DIV_CEIL(
Expand Down Expand Up @@ -4941,7 +5004,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL;
int num_rails = ep->num_rails;
int rail_id = 0;
nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0);
nccl_net_ofi_ep_rail_t *control_rail = &ep->control_rail;
*s_comm = NULL;

/* Retrieve and validate device */
Expand Down Expand Up @@ -4994,7 +5057,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
ret_s_comm->num_rails = num_rails;

/* Insert remote name into AV of first rail */
ret = fi_av_insert(first_rail->av,
ret = fi_av_insert(control_rail->av,
(void *)handle->ep_name, 1,
&remote_addr, 0, NULL);
if (OFI_UNLIKELY(ret != 1)) {
Expand All @@ -5005,11 +5068,11 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
}

/* Store remote address of first rail in communicator */
ret_s_comm->rails[0].remote_addr = remote_addr;
ret_s_comm->control_rail.remote_addr = remote_addr;

/* Store local libfabric endpoint of first rail */
ret_s_comm->rails[0].local_ep = first_rail->ofi_ep;
ret_s_comm->num_init_rails = 1;
/* Store local libfabric endpoint of control rail */
ret_s_comm->control_rail.local_ep = control_rail->ofi_ep;
ret_s_comm->num_init_rails = 0;

/* Allocate request free list */
ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), 16, 16,
Expand Down Expand Up @@ -5127,7 +5190,7 @@ static int post_send_conn(nccl_net_ofi_rdma_send_comm_t *s_comm,
nccl_net_ofi_rdma_req_t *req)
{
ssize_t rc = 0;
nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = get_send_comm_rail(s_comm, 0);
nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = &s_comm->control_rail;

/*
* TODO: replace it with API of FI_INJECT type when most of
Expand Down Expand Up @@ -5349,6 +5412,7 @@ static void ep_rail_release(nccl_net_ofi_ep_rail_t *rail, int dev_id)
*/
static void release_rdma_ep_resources(nccl_net_ofi_rdma_ep_t *ep, int dev_id)
{
ep_rail_release(&ep->control_rail, dev_id);
for (int rail_id = 0; rail_id != ep->num_rails; ++rail_id) {
ep_rail_release(get_rail(ep, rail_id), dev_id);
}
Expand Down Expand Up @@ -5584,6 +5648,15 @@ static int create_ep(nccl_net_ofi_rdma_device_t *device,

ep->use_long_rkeys = device->use_long_rkeys;

/* we pass 0 as the railid for the control rail, so
* that any lookups based on railid in the domain find
* the right domain */
ret = ep_rail_init(ep, device->base.dev_id, 0, &device->device_rails[0], &ep->control_rail);
if (ret != 0) {
NCCL_OFI_WARN("Initializing control rail failed");
goto error;
}

ret = init_rail_ofi_resources(device, ep);
if (ret != 0) {
goto error;
Expand Down

0 comments on commit b7a686a

Please sign in to comment.