diff --git a/src/team_lib/mhba/xccl_mhba_collective.c b/src/team_lib/mhba/xccl_mhba_collective.c index 85d67e0..6927f8e 100644 --- a/src/team_lib/mhba/xccl_mhba_collective.c +++ b/src/team_lib/mhba/xccl_mhba_collective.c @@ -153,7 +153,7 @@ static xccl_status_t xccl_mhba_fanout_start(xccl_coll_task_t *task) /* start task if completion event received */ task->state = XCCL_TASK_STATE_INPROGRESS; - /* Start fanin */ + /* Start fanout */ if (XCCL_OK == xccl_mhba_node_fanout(team, request)) { task->state = XCCL_TASK_STATE_COMPLETED; xccl_mhba_debug("Algorithm completion"); @@ -179,12 +179,43 @@ static xccl_status_t xccl_mhba_fanout_progress(xccl_coll_task_t *task) return XCCL_OK; } +static inline xccl_status_t send_block_data(struct ibv_qp *qp, + uint64_t src_addr, + uint32_t msg_size, uint32_t lkey, + uint64_t remote_addr, uint32_t rkey, + int send_flags, int with_imm) +{ + struct ibv_send_wr *bad_wr; + struct ibv_sge list = { + .addr = src_addr, + .length = msg_size, + .lkey = lkey, + }; + + struct ibv_send_wr wr = { + .wr_id = 1, + .sg_list = &list, + .num_sge = 1, + .opcode = with_imm ? 
IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE, + .send_flags = send_flags, + .wr.rdma.remote_addr = remote_addr, + .wr.rdma.rkey = rkey, + }; + + if (ibv_post_send(qp, &wr, &bad_wr)) { + xccl_mhba_error("failed to post send"); + return XCCL_ERR_NO_MESSAGE; + } + return XCCL_OK; +} + static xccl_status_t xccl_mhba_asr_barrier_start(xccl_coll_task_t *task) { xccl_mhba_task_t *self = ucs_derived_of(task, xccl_mhba_task_t); xccl_mhba_coll_req_t *request = self->req; xccl_mhba_team_t *team = request->team; xccl_mhba_debug("asr barrier start"); + int i; if(request->buff_change_flag) { // despite while statement, non blocking because have independent cq, will be finished in a finite time @@ -192,65 +223,31 @@ static xccl_status_t xccl_mhba_asr_barrier_start(xccl_coll_task_t *task) } //Reset atomic notification counter to 0 - memset(team->node.storage + MHBA_CTRL_SIZE * SEQ_INDEX(request->seq_num), 0, - MHBA_CTRL_SIZE); + memset(team->node.storage + MHBA_CTRL_SIZE * SEQ_INDEX(request->seq_num), 0, MHBA_CTRL_SIZE); - task->state = XCCL_TASK_STATE_INPROGRESS; - xccl_coll_op_args_t coll = { - .coll_type = XCCL_BARRIER, - .alg.set_by_user = 0, - }; - //todo create special barrier to support multiple parallel ops - with seq_id - team->net.ucx_team->ctx->lib->collective_init(&coll, &request->barrier_req, - team->net.ucx_team); - team->net.ucx_team->ctx->lib->collective_post(request->barrier_req); + task->state = XCCL_TASK_STATE_COMPLETED; + + team->inter_node_barrier[team->net.sbgp->group_rank] = request->seq_num; + for(i=0; i<team->net.net_size;i++){ + xccl_status_t status = send_block_data(team->net.qps[i], (uintptr_t)team->inter_node_barrier_mr->addr+team->net.sbgp->group_rank*sizeof(int) , sizeof(int), + team->inter_node_barrier_mr->lkey, + team->net.remote_ctrl[i].barrier_addr+sizeof(int)*team->net.sbgp->group_rank, team->net.remote_ctrl[i].barrier_rkey, 0, 0); + if (status != XCCL_OK) { + xccl_mhba_error("Failed sending barrier notice"); + return status; + } + } 
xccl_task_enqueue(task->schedule->tl_ctx->pq, task); return XCCL_OK; } -xccl_status_t xccl_mhba_asr_barrier_progress(xccl_coll_task_t *task) +xccl_status_t xccl_mhba_asr_barrier_progress(xccl_coll_task_t *task) //todo not needed { - xccl_mhba_task_t *self = ucs_derived_of(task, xccl_mhba_task_t); - xccl_mhba_coll_req_t *request = self->req; - xccl_mhba_team_t *team = request->team; - - if (XCCL_OK == - team->net.ucx_team->ctx->lib->collective_test(request->barrier_req)) { - team->net.ucx_team->ctx->lib->collective_finalize(request->barrier_req); - task->state = XCCL_TASK_STATE_COMPLETED; - } + task->state = XCCL_TASK_STATE_COMPLETED; return XCCL_OK; } -static inline xccl_status_t send_block_data(struct ibv_qp *qp, - uint64_t src_addr, - uint32_t msg_size, uint32_t lkey, - uint64_t remote_addr, uint32_t rkey, - int send_flags, int with_imm) -{ - struct ibv_send_wr *bad_wr; - struct ibv_sge list = { - .addr = src_addr, - .length = msg_size, - .lkey = lkey, - }; - - struct ibv_send_wr wr = { - .wr_id = 1, - .sg_list = &list, - .num_sge = 1, - .opcode = with_imm ? 
IBV_WR_RDMA_WRITE_WITH_IMM : IBV_WR_RDMA_WRITE, - .send_flags = send_flags, - .wr.rdma.remote_addr = remote_addr, - .wr.rdma.rkey = rkey, - }; - if (ibv_post_send(qp, &wr, &bad_wr)) { - xccl_mhba_error("failed to post send"); - return XCCL_ERR_NO_MESSAGE; - } - return XCCL_OK; -} static inline xccl_status_t send_atomic(struct ibv_qp *qp, uint64_t remote_addr, uint32_t rkey, xccl_mhba_team_t *team, @@ -337,70 +334,76 @@ xccl_mhba_send_blocks_start_with_transpose(xccl_coll_task_t *task) int i, j, k, dest_rank, rank, n_compl, ret; uint64_t src_addr, remote_addr; struct ibv_wc transpose_completion[1]; + int counter = 0; xccl_status_t status; xccl_mhba_debug("send blocks start"); task->state = XCCL_TASK_STATE_INPROGRESS; rank = team->net.rank_map[team->net.sbgp->group_rank]; - for (i = 0; i < net_size; i++) { - dest_rank = team->net.rank_map[i]; - //send all blocks from curr node to some ARR - for (j = 0; j < xccl_round_up(node_size, block_size); j++) { - for (k = 0; k < xccl_round_up(node_size, block_size); k++) { - src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank + - col_msgsize * j + block_msgsize * k); - remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank + - block_msgsize * j + col_msgsize * k); - prepost_dummy_recv(team->node.umr_qp, 1); - // SW Transpose - status = send_block_data( - team->node.umr_qp, src_addr, block_msgsize, - team->node.team_send_mkey->lkey, - (uintptr_t)request->transpose_buf_mr->addr, - request->transpose_buf_mr->rkey, IBV_SEND_SIGNALED, 1); - if (status != XCCL_OK) { - xccl_mhba_error( - "Failed sending block to transpose buffer[%d,%d,%d]", i, j, k); - return status; - } - n_compl = 0; - while (n_compl != 2) { - ret = ibv_poll_cq(team->node.umr_cq, 1, transpose_completion); - if (ret > 0) { - if (transpose_completion[0].status != IBV_WC_SUCCESS) { + while(counter < net_size) { + for (i = 0; i < net_size; i++) { + if (team->inter_node_barrier[i] == request->seq_num && !team->inter_node_barrier_flag[i]) { + 
team->inter_node_barrier_flag[i] = 1; + dest_rank = team->net.rank_map[i]; + //send all blocks from curr node to some ARR + for (j = 0; j < xccl_round_up(node_size, block_size); j++) { + for (k = 0; k < xccl_round_up(node_size, block_size); k++) { + src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank + + col_msgsize * j + block_msgsize * k); + remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank + + block_msgsize * j + col_msgsize * k); + prepost_dummy_recv(team->node.umr_qp, 1); + // SW Transpose + status = send_block_data( + team->node.umr_qp, src_addr, block_msgsize, + team->node.team_send_mkey->lkey, + (uintptr_t) request->transpose_buf_mr->addr, + request->transpose_buf_mr->rkey, IBV_SEND_SIGNALED, 1); + if (status != XCCL_OK) { xccl_mhba_error( - "local copy for transpose CQ returned " - "completion with status %s (%d)", - ibv_wc_status_str(transpose_completion[0].status), - transpose_completion[0].status); - return XCCL_ERR_NO_MESSAGE; + "Failed sending block to transpose buffer[%d,%d,%d]", i, j, k); + return status; + } + n_compl = 0; + while (n_compl != 2) { + ret = ibv_poll_cq(team->node.umr_cq, 1, transpose_completion); + if (ret > 0) { + if (transpose_completion[0].status != IBV_WC_SUCCESS) { + xccl_mhba_error( + "local copy for transpose CQ returned " + "completion with status %s (%d)", + ibv_wc_status_str(transpose_completion[0].status), + transpose_completion[0].status); + return XCCL_ERR_NO_MESSAGE; + } + n_compl++; + } } - n_compl++; + transpose_square_mat(request->transpose_buf_mr->addr, + block_size, request->args.buffer_info.len, + request->tmp_transpose_buf); + status = send_block_data( + team->net.qps[i], + (uintptr_t) request->transpose_buf_mr->addr, block_msgsize, + request->transpose_buf_mr->lkey, remote_addr, + team->net.rkeys[i], IBV_SEND_SIGNALED, 0); + if (status != XCCL_OK) { + xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k); + return status; + } + while (!ibv_poll_cq(team->net.cq, 1, 
transpose_completion)) {} } } - transpose_square_mat(request->transpose_buf_mr->addr, - block_size, request->args.buffer_info.len, - request->tmp_transpose_buf); - status = send_block_data( - team->net.qps[i], - (uintptr_t)request->transpose_buf_mr->addr, block_msgsize, - request->transpose_buf_mr->lkey, remote_addr, - team->net.rkeys[i], IBV_SEND_SIGNALED, 0); - if (status != XCCL_OK) { - xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k); - return status; - } - while (!ibv_poll_cq(team->net.cq, 1, transpose_completion)) {} + counter += 1; } } } - for (i = 0; i < net_size; i++) { status = send_atomic(team->net.qps[i], - (uintptr_t)team->net.remote_ctrl[i].addr + + (uintptr_t)team->net.remote_ctrl[i].ctrl_addr + (index * MHBA_CTRL_SIZE), - team->net.remote_ctrl[i].rkey, team, request); + team->net.remote_ctrl[i].ctrl_rkey, team, request); if (status != XCCL_OK) { xccl_mhba_error("Failed sending atomic to node [%d]", i); return status; @@ -426,6 +429,7 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task) int col_msgsize = len * block_size * node_size; int block_msgsize = SQUARED(block_size) * len; int i, j, k, dest_rank, rank; + int counter = 0; uint64_t src_addr, remote_addr; xccl_status_t status; @@ -433,33 +437,39 @@ static xccl_status_t xccl_mhba_send_blocks_start(xccl_coll_task_t *task) task->state = XCCL_TASK_STATE_INPROGRESS; rank = team->net.rank_map[team->net.sbgp->group_rank]; - for (i = 0; i < net_size; i++) { - dest_rank = team->net.rank_map[i]; - //send all blocks from curr node to some ARR - for (j = 0; j < xccl_round_up(node_size, block_size); j++) { - for (k = 0; k < xccl_round_up(node_size, block_size); k++) { - src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank + - col_msgsize * j + block_msgsize * k); - remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank + - block_msgsize * j + col_msgsize * k); - - status = send_block_data(team->net.qps[i], src_addr, block_msgsize, - 
team->node.team_send_mkey->lkey, - remote_addr, team->net.rkeys[i], 0, 0); + while(counter < net_size) { + for (i = 0; i < net_size; i++) { + if (team->inter_node_barrier[i] == request->seq_num && !team->inter_node_barrier_flag[i]) { + team->inter_node_barrier_flag[i] = 1; + dest_rank = team->net.rank_map[i]; + //send all blocks from curr node to some ARR + for (j = 0; j < xccl_round_up(node_size, block_size); j++) { + for (k = 0; k < xccl_round_up(node_size, block_size); k++) { + src_addr = (uintptr_t)(op_msgsize * index + node_msgsize * dest_rank + + col_msgsize * j + block_msgsize * k); + remote_addr = (uintptr_t)(op_msgsize * index + node_msgsize * rank + + block_msgsize * j + col_msgsize * k); + + status = send_block_data(team->net.qps[i], src_addr, block_msgsize, + team->node.team_send_mkey->lkey, + remote_addr, team->net.rkeys[i], 0, 0); + if (status != XCCL_OK) { + xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k); + return status; + } + } + } + status = send_atomic(team->net.qps[i], + (uintptr_t) team->net.remote_ctrl[i].ctrl_addr + + (index * MHBA_CTRL_SIZE), + team->net.remote_ctrl[i].ctrl_rkey, team, request); if (status != XCCL_OK) { - xccl_mhba_error("Failed sending block [%d,%d,%d]", i, j, k); + xccl_mhba_error("Failed sending atomic to node [%d]", i); return status; } + counter += 1; } } - status = send_atomic(team->net.qps[i], - (uintptr_t)team->net.remote_ctrl[i].addr + - (index * MHBA_CTRL_SIZE), - team->net.remote_ctrl[i].rkey, team, request); - if (status != XCCL_OK) { - xccl_mhba_error("Failed sending atomic to node [%d]", i); - return status; - } } xccl_task_enqueue(task->schedule->tl_ctx->pq, task); return XCCL_OK; @@ -580,6 +590,7 @@ xccl_status_t xccl_mhba_alltoall_init(xccl_coll_op_args_t *coll_args, xccl_mhba_fanout_start; request->tasks[1].super.progress = xccl_mhba_fanout_progress; } else { + memset(team->inter_node_barrier_flag,0,sizeof(int)*team->net.net_size); request->tasks[1].super.handlers[XCCL_EVENT_COMPLETED] = 
xccl_mhba_asr_barrier_start; request->tasks[1].super.progress = xccl_mhba_asr_barrier_progress; diff --git a/src/team_lib/mhba/xccl_mhba_lib.c b/src/team_lib/mhba/xccl_mhba_lib.c index d70b5cb..52e2eae 100644 --- a/src/team_lib/mhba/xccl_mhba_lib.c +++ b/src/team_lib/mhba/xccl_mhba_lib.c @@ -24,11 +24,11 @@ static ucs_config_field_t xccl_tl_mhba_context_config_table[] = { ucs_offsetof(xccl_tl_mhba_context_config_t, devices), UCS_CONFIG_TYPE_STRING_ARRAY}, - {"TRANSPOSE", "1", "Boolean - with transpose or not", + {"TRANSPOSE", "0", "Boolean - with transpose or not", ucs_offsetof(xccl_tl_mhba_context_config_t, transpose), UCS_CONFIG_TYPE_UINT}, - {"TRANSPOSE_HW_LIMITATIONS", "0", + {"TRANSPOSE_HW_LIMITATIONS", "1", "Boolean - with transpose hw limitations or not", ucs_offsetof(xccl_tl_mhba_context_config_t, transpose_hw_limitations), UCS_CONFIG_TYPE_UINT}, //todo change to 1 in production diff --git a/src/team_lib/mhba/xccl_mhba_lib.h b/src/team_lib/mhba/xccl_mhba_lib.h index 877f984..7e79420 100644 --- a/src/team_lib/mhba/xccl_mhba_lib.h +++ b/src/team_lib/mhba/xccl_mhba_lib.h @@ -129,8 +129,10 @@ typedef struct xccl_mhba_net { struct ibv_cq *cq; struct ibv_mr *ctrl_mr; struct { - void *addr; - uint32_t rkey; + void *ctrl_addr; + uint32_t ctrl_rkey; + void *barrier_addr; + uint32_t barrier_rkey; } * remote_ctrl; uint32_t *rkeys; xccl_tl_team_t *ucx_team; @@ -143,6 +145,10 @@ typedef struct xccl_mhba_team { uint64_t max_msg_size; xccl_mhba_node_t node; xccl_mhba_net_t net; + int* inter_node_barrier; + int* inter_node_barrier_flag; + struct ibv_mr *inter_node_barrier_mr; + struct ibv_mr **net_barrier_mr; int sequence_number; int op_busy[MAX_OUTSTANDING_OPS]; int cq_completions[MAX_OUTSTANDING_OPS]; diff --git a/src/team_lib/mhba/xccl_mhba_team.c b/src/team_lib/mhba/xccl_mhba_team.c index 4c50958..d681fb1 100644 --- a/src/team_lib/mhba/xccl_mhba_team.c +++ b/src/team_lib/mhba/xccl_mhba_team.c @@ -308,7 +308,7 @@ xccl_status_t 
xccl_mhba_team_create_post(xccl_tl_context_t *context, } // for each ASR - qp num, in addition to port lid, ctrl segment rkey and address, recieve mkey rkey local_data_size = (net_size * sizeof(uint32_t)) + sizeof(uint32_t) + - 2 * sizeof(uint32_t) + sizeof(void *); + 3 * sizeof(uint32_t) + 2*sizeof(void *); //todo make concurrent local_data = malloc(local_data_size); if (!local_data) { xccl_mhba_error("failed to allocate local data"); @@ -361,6 +361,22 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context, local_data[net_size + 4] = mhba_team->node.team_recv_mkey->rkey; + mhba_team->inter_node_barrier = (int*) malloc(sizeof(int)*net_size); + mhba_team->inter_node_barrier_flag = (int*) malloc(sizeof(int)*net_size); + int p; + for(p=0;p<net_size;p++){ + mhba_team->inter_node_barrier[p] = -1; + } + mhba_team->inter_node_barrier_mr = ibv_reg_mr(mhba_team->node.shared_pd, mhba_team->inter_node_barrier, + sizeof(uint32_t)*net_size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!mhba_team->inter_node_barrier_mr) { + xccl_mhba_error("Failed to register memory"); + return UCS_ERR_NO_MESSAGE; + } + local_data[net_size + 5] = mhba_team->inter_node_barrier_mr->rkey; + *((uint64_t *)&local_data[net_size + 6]) = + (uint64_t)(uintptr_t)mhba_team->inter_node_barrier_mr->addr; + xccl_sbgp_oob_allgather(local_data, global_data, local_data_size, net, params->oob); mhba_team->net.rkeys = (uint32_t *)malloc(sizeof(uint32_t) * net_size); @@ -370,9 +386,12 @@ xccl_status_t xccl_mhba_team_create_post(xccl_tl_context_t *context, xccl_mhba_qp_connect(mhba_team->net.qps[i], remote_data[net->group_rank], remote_data[net_size], ctx->ib_port); - mhba_team->net.remote_ctrl[i].rkey = remote_data[net_size + 1]; - mhba_team->net.remote_ctrl[i].addr = + mhba_team->net.remote_ctrl[i].ctrl_rkey = remote_data[net_size + 1]; + mhba_team->net.remote_ctrl[i].ctrl_addr = (void *)(uintptr_t)(*((uint64_t *)&remote_data[net_size + 2])); + mhba_team->net.remote_ctrl[i].barrier_rkey = remote_data[net_size + 5]; 
+ mhba_team->net.remote_ctrl[i].barrier_addr = + (void *)(uintptr_t)(*((uint64_t *)&remote_data[net_size + 6])); mhba_team->net.rkeys[i] = remote_data[net_size + 4]; } xccl_sbgp_oob_barrier(net, params->oob); @@ -514,6 +533,9 @@ xccl_status_t xccl_mhba_team_destroy(xccl_tl_team_t *team) ibv_dereg_mr(mhba_team->transpose_buf_mr); free(mhba_team->transpose_buf); } + ibv_dereg_mr(mhba_team->inter_node_barrier_mr); + free(mhba_team->inter_node_barrier); + free(mhba_team->inter_node_barrier_flag); } free(team); return status;