Skip to content

Commit

Permalink
Add net.cqs
Browse files Browse the repository at this point in the history
  • Loading branch information
Lior Paz committed Dec 29, 2020
1 parent af1d0bb commit 08a6cf7
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 41 deletions.
30 changes: 17 additions & 13 deletions src/team_lib/mhba/xccl_mhba_collective.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <xccl_mhba_collective.h>
#include "utils/utils.h"

#define TMP_TRANSPOSE_PREALLOC 256
#define TMP_TRANSPOSE_PREALLOC 256 //todo check size

xccl_status_t xccl_mhba_collective_init_base(xccl_coll_op_args_t *coll_args,
xccl_mhba_coll_req_t **request,
Expand Down Expand Up @@ -249,7 +249,7 @@ static inline xccl_status_t send_block_data(struct ibv_qp *qp,

static inline xccl_status_t send_atomic(struct ibv_qp *qp, uint64_t remote_addr,
uint32_t rkey, xccl_mhba_team_t *team,
xccl_mhba_coll_req_t *request)
xccl_mhba_coll_req_t *request, int flags)
{
struct ibv_send_wr *bad_wr;
struct ibv_sge list = {
Expand All @@ -263,7 +263,7 @@ static inline xccl_status_t send_atomic(struct ibv_qp *qp, uint64_t remote_addr,
.sg_list = &list,
.num_sge = 1,
.opcode = IBV_WR_ATOMIC_FETCH_AND_ADD,
.send_flags = IBV_SEND_SIGNALED,
.send_flags = flags,
.wr.atomic.remote_addr = remote_addr,
.wr.atomic.rkey = rkey,
.wr.atomic.compare_add = 1ULL,
Expand Down Expand Up @@ -316,7 +316,7 @@ static inline xccl_status_t prepost_dummy_recv(struct ibv_qp *qp, int num)
// TODO: add a polling mechanism for blocks so the QP tx/rx depth stays constant
static xccl_status_t
xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task)
{
{ //todo make non-blocking
xccl_mhba_task_t *self = ucs_derived_of(task, xccl_mhba_task_t);
xccl_mhba_coll_req_t *request = self->req;
xccl_mhba_team_t *team = request->team;
Expand Down Expand Up @@ -386,21 +386,19 @@ xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task)
xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k);
return status;
}
while (!ibv_poll_cq(team->net.cq, 1, transpose_completion)) {}
while (!ibv_poll_cq(team->net.cqs[i], 1, transpose_completion)) {}
}
}
}

for (i = 0; i < net_size; i++) {
status = send_atomic(team->net.qps[i],
(uintptr_t)team->net.remote_ctrl[i].addr +
(index * MHBA_CTRL_SIZE),
team->net.remote_ctrl[i].rkey, team, request);
(index * MHBA_CTRL_SIZE),
team->net.remote_ctrl[i].rkey, team, request,0);
if (status != XCCL_OK) {
xccl_mhba_error("Failed sending atomic to node [%d]", i);
return status;
}
}

xccl_task_enqueue(task->schedule->tl_ctx->pq, task);
return XCCL_OK;
}
Expand Down Expand Up @@ -450,7 +448,7 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task)
status = send_atomic(team->net.qps[i],
(uintptr_t)team->net.remote_ctrl[i].addr +
(index * MHBA_CTRL_SIZE),
team->net.remote_ctrl[i].rkey, team, request);
team->net.remote_ctrl[i].rkey, team, request,IBV_SEND_SIGNALED);
if (status != XCCL_OK) {
xccl_mhba_error("Failed sending atomic to node [%d]", i);
return status;
Expand All @@ -460,13 +458,18 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task)
return XCCL_OK;
}

/* Progress handler for the HW-transpose send path.
 *
 * The matching start handler (xccl_mhba_send_blocks_start_with_transpose)
 * is currently blocking: it spins on ibv_poll_cq() for every block it
 * sends, and posts its final atomics with send_flags == 0 (unsignaled),
 * so no completions remain for progress to reap.  The task can therefore
 * be marked completed immediately.
 *
 * NOTE(review): this only holds while the start handler stays blocking —
 * see its "todo make non-blocking" marker; if that changes, real CQ
 * polling must move here (as in xccl_mhba_send_blocks_progress).
 *
 * @param task  collective task whose send phase has already finished.
 * @return XCCL_OK always.
 */
xccl_status_t xccl_mhba_send_blocks_progress_transpose(xccl_coll_task_t *task){
task->state = XCCL_TASK_STATE_COMPLETED;
return XCCL_OK;
}

xccl_status_t xccl_mhba_send_blocks_progress(xccl_coll_task_t *task)
{
xccl_mhba_task_t *self = ucs_derived_of(task, xccl_mhba_task_t);
xccl_mhba_coll_req_t *request = self->req;
xccl_mhba_team_t *team = request->team;
int i, completions_num;
completions_num = ibv_poll_cq(team->net.cq, team->net.sbgp->group_size,
completions_num = ibv_poll_cq(team->net.cqs[0], team->net.sbgp->group_size,
team->work_completion);
if (completions_num < 0) {
xccl_mhba_error("ibv_poll_cq() failed for RDMA_ATOMIC execution");
Expand Down Expand Up @@ -579,11 +582,12 @@ xccl_status_t xccl_mhba_alltoall_init(xccl_coll_op_args_t *coll_args,
if (team->transpose) {
request->tasks[2].super.handlers[XCCL_EVENT_COMPLETED] =
xccl_mhba_send_blocks_start_with_transpose;
request->tasks[2].super.progress = xccl_mhba_send_blocks_progress_transpose;
} else {
request->tasks[2].super.handlers[XCCL_EVENT_COMPLETED] =
xccl_mhba_send_blocks_start;
request->tasks[2].super.progress = xccl_mhba_send_blocks_progress;
}
request->tasks[2].super.progress = xccl_mhba_send_blocks_progress;

request->tasks[3].super.handlers[XCCL_EVENT_COMPLETED] =
xccl_mhba_fanout_start;
Expand Down
2 changes: 1 addition & 1 deletion src/team_lib/mhba/xccl_mhba_collective.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ typedef struct xccl_mhba_coll_req {
xccl_mhba_task_t *tasks;
xccl_coll_op_args_t args;
xccl_mhba_team_t *team;
int seq_num;
uint32_t seq_num;
xccl_tl_coll_req_t *barrier_req;
int block_size;
int started;
Expand Down
8 changes: 3 additions & 5 deletions src/team_lib/mhba/xccl_mhba_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ typedef struct xccl_tl_mhba_context_config {
int block_size;
} xccl_tl_mhba_context_config_t;

//todo add block_size config

typedef struct xccl_team_lib_mhba {
xccl_team_lib_t super;
xccl_team_lib_mhba_config_t config;
Expand Down Expand Up @@ -105,7 +103,7 @@ typedef struct xccl_mhba_node {
struct mlx5dv_qp_ex *umr_mlx5dv_qp_ex;
} xccl_mhba_node_t;

#define MHBA_CTRL_SIZE 128 //todo change according to arch
#define MHBA_CTRL_SIZE 128 //todo change to UCS_ARCH_CACHE_LINE_SIZE
#define MHBA_DATA_SIZE sizeof(struct mlx5dv_mr_interleaved)
#define MHBA_NUM_OF_BLOCKS_SIZE_BINS 8
#define MAX_TRANSPOSE_SIZE 8000 // HW transpose unit is limited to matrix size
Expand All @@ -126,7 +124,7 @@ typedef struct xccl_mhba_net {
int net_size;
int *rank_map;
struct ibv_qp **qps;
struct ibv_cq *cq;
struct ibv_cq **cqs;
struct ibv_mr *ctrl_mr;
struct {
void *addr;
Expand All @@ -143,7 +141,7 @@ typedef struct xccl_mhba_team {
uint64_t max_msg_size;
xccl_mhba_node_t node;
xccl_mhba_net_t net;
int sequence_number;
uint32_t sequence_number;
int op_busy[MAX_OUTSTANDING_OPS];
int cq_completions[MAX_OUTSTANDING_OPS];
xccl_mhba_context_t *context;
Expand Down
58 changes: 36 additions & 22 deletions src/team_lib/mhba/xccl_mhba_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
mhba_team->requested_block_size = ctx->cfg.block_size;
if (mhba_team->node.asr_rank == node->group_rank) {
if (mhba_team->transpose) {
mhba_team->transpose_buf = malloc(ctx->cfg.transpose_buf_size);
mhba_team->transpose_buf = malloc(ctx->cfg.transpose_buf_size); //todo malloc per operation for parallel
if (!mhba_team->transpose_buf) {
goto fail_after_shmat;
}
Expand All @@ -279,18 +279,9 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
xccl_mhba_error("Failed to init UMR");
goto fail_after_transpose_reg;
}
asr_cq_size = net_size * MAX_OUTSTANDING_OPS;
mhba_team->net.cq = ibv_create_cq(mhba_team->node.shared_ctx,
asr_cq_size, NULL, NULL, 0);
if (!mhba_team->net.cq) {
xccl_mhba_error("failed to allocate ASR CQ");
goto fail_after_transpose_reg;
}

memset(&qp_init_attr, 0, sizeof(qp_init_attr));
//todo change in case of non-homogeneous ppn
qp_init_attr.send_cq = mhba_team->net.cq;
qp_init_attr.recv_cq = mhba_team->net.cq;
qp_init_attr.cap.max_send_wr =
(SQUARED(node_size / 2) + 1) * MAX_OUTSTANDING_OPS; // TODO switch back to fixed tx/rx
qp_init_attr.cap.max_recv_wr =
Expand All @@ -303,15 +294,21 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
mhba_team->net.qps = malloc(sizeof(struct ibv_qp *) * net_size);
if (!mhba_team->net.qps) {
xccl_mhba_error("failed to allocate asr qps array");
goto fail_after_cq;
goto fail_after_transpose_reg;
}
mhba_team->net.cqs = malloc(sizeof(struct ibv_cq *) * (mhba_team->transpose ? net_size : 1));
if (!mhba_team->net.cqs) {
xccl_mhba_error("failed to allocate asr cqs array");
goto fail_after_qp_alloc;
}

// for each ASR - qp num, in addition to port lid, ctrl segment rkey and address, receive mkey rkey
local_data_size = (net_size * sizeof(uint32_t)) + sizeof(uint32_t) +
2 * sizeof(uint32_t) + sizeof(void *);
local_data = malloc(local_data_size);
if (!local_data) {
xccl_mhba_error("failed to allocate local data");
goto local_data_fail;
goto fail_after_cq_alloc;
}
global_data = malloc(local_data_size * net_size);
if (!global_data) {
Expand All @@ -320,12 +317,24 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
}

for (i = 0; i < net_size; i++) {
if(i == 0 || mhba_team->transpose){
mhba_team->net.cqs[i] = ibv_create_cq(mhba_team->node.shared_ctx, mhba_team->transpose ?
MAX_OUTSTANDING_OPS : net_size * MAX_OUTSTANDING_OPS, NULL, NULL, 0);
if (!mhba_team->net.cqs[i]) {
xccl_mhba_error("failed to create cq for dest %d, errno %d", i,
errno);
goto cq_qp_creation_fail;
}
qp_init_attr.send_cq = mhba_team->net.cqs[i];
qp_init_attr.recv_cq = mhba_team->net.cqs[i];
}

mhba_team->net.qps[i] =
ibv_create_qp(mhba_team->node.shared_pd, &qp_init_attr);
if (!mhba_team->net.qps[i]) {
xccl_mhba_error("failed to create qp for dest %d, errno %d", i,
errno);
goto ctrl_fail;
goto cq_qp_creation_fail;
}
local_data[i] = mhba_team->net.qps[i]->qp_num;
}
Expand All @@ -337,7 +346,7 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_LOCAL_WRITE);
if (!mhba_team->net.ctrl_mr) {
xccl_mhba_error("failed to register control data, errno %d", errno);
goto ctrl_fail;
goto cq_qp_creation_fail;
}
ibv_query_port(ctx->ib_ctx, ctx->ib_port, &port_attr);
local_data[net_size] = port_attr.lid;
Expand Down Expand Up @@ -436,16 +445,18 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context,
ibv_dereg_mr(mhba_team->dummy_bf_mr);
remote_ctrl_fail:
ibv_dereg_mr(mhba_team->net.ctrl_mr);
ctrl_fail:
cq_qp_creation_fail:
free(global_data);
for (i = 0; i < net_size; i++){
ibv_destroy_cq(mhba_team->net.cqs[i]);
ibv_destroy_qp(mhba_team->net.qps[i]);
}
global_data_fail:
free(local_data);
local_data_fail:
fail_after_cq_alloc:
free(mhba_team->net.cqs);
fail_after_qp_alloc:
free(mhba_team->net.qps);
fail_after_cq:
if (ibv_destroy_cq(mhba_team->net.cq)) {
xccl_mhba_error("net cq destroy failed (errno=%d)", errno);
}
fail_after_transpose_reg:
ibv_dereg_mr(mhba_team->transpose_buf_mr);
free(mhba_team->transpose_buf);
Expand Down Expand Up @@ -495,9 +506,12 @@ xccl_status_t xccl_mhba_team_destroy(xccl_tl_team_t *team)
ibv_destroy_qp(mhba_team->net.qps[i]);
}
free(mhba_team->net.qps);
if (ibv_destroy_cq(mhba_team->net.cq)) {
xccl_mhba_error("net cq destroy failed (errno=%d)", errno);
for (i = 0; i < (mhba_team->transpose ? mhba_team->net.sbgp->group_size : 1); i++) {
if (ibv_destroy_cq(mhba_team->net.cqs[i])) {
xccl_mhba_error("net cq destroy failed (errno=%d)", errno);
}
}
free(mhba_team->net.cqs);
mhba_team->net.ucx_team->ctx->lib->team_destroy(
mhba_team->net.ucx_team);

Expand Down

0 comments on commit 08a6cf7

Please sign in to comment.