TL/MLX5: generate schedule for zcopy allgather #1059
@@ -270,6 +270,152 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
        }
    }

static inline ucc_status_t
ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                                      ucc_tl_mlx5_mcast_coll_req_t *req)
{
    if ((req->concurreny_level % 2 == 0 && req->num_packets % req->mcast_prepost_bucket_size != 0) ||
Review comment: Not required, only suggesting: since the conditions are rather independent, I would separate them into separate `if` statements.
        (comm->commsize % req->concurreny_level != 0) ||
        (req->length % comm->max_per_packet != 0)) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "num_packets %d mcast_prepost_bucket_size %d "
                "length %ld max_per_packet %d "
                "team size %d concurreny_level %d",
                req->num_packets, req->mcast_prepost_bucket_size, req->length,
                comm->max_per_packet, comm->commsize, req->concurreny_level);
Review comment: nit: should …
        return UCC_ERR_NOT_SUPPORTED;
    }

    if (req->mcast_prepost_bucket_size * req->concurreny_level * 2 > comm->params.rx_depth) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "either reduce prepost_bucket_size or mcast group "
                "count or increase recv queue size "
                "mcast_prepost_bucket_size %d concurreny_level %d "
                "rx_depth %d",
||||||||||||||||
req->mcast_prepost_bucket_size, req->concurreny_level, | ||||||||||||||||
comm->params.rx_depth); | ||||||||||||||||
return UCC_ERR_NOT_SUPPORTED; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
return UCC_OK; | ||||||||||||||||
} | ||||||||||||||||
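
To make the constraints above concrete, here is a minimal standalone sketch (not part of the PR; all values are assumed for illustration) that applies the same divisibility and queue-depth checks:

```c
#include <stdio.h>

int main(void)
{
    /* assumed example configuration, not taken from the PR */
    int commsize                  = 8;    /* team size */
    int concurreny_level          = 2;    /* spelled as in the PR's field name */
    int num_packets               = 4;    /* packets each rank multicasts */
    int mcast_prepost_bucket_size = 2;    /* recvs preposted per stage */
    int length                    = 4096; /* per-rank message size */
    int max_per_packet            = 1024; /* payload bytes per mcast packet */
    int rx_depth                  = 64;   /* recv queue depth */

    if ((concurreny_level % 2 == 0 && num_packets % mcast_prepost_bucket_size != 0) ||
        (commsize % concurreny_level != 0) ||
        (length % max_per_packet != 0)) {
        printf("not supported: divisibility constraints violated\n");
    } else if (mcast_prepost_bucket_size * concurreny_level * 2 > rx_depth) {
        printf("not supported: recv queue too shallow\n");
    } else {
        printf("supported\n"); /* prints this for the values above */
    }
    return 0;
}
```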
/*
 * at each stage half of the mcast groups are ready for receiving mcast
 * packets while the other half are getting prepared by preposting recv
 * buffers
 */
static inline ucc_status_t
ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                              ucc_tl_mlx5_mcast_coll_req_t *req)
{
    ucc_tl_mlx5_mcast_reg_t                   *reg    = NULL;
    ucc_rank_t                                 root   = 0;
    int                                        offset = 0;
    ucc_status_t                               status;
    ucc_rank_t                                 j, i;
    int                                        total_steps;
    ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *new_sched;

    ucc_assert(comm->allgather_comm.truly_zero_copy_allgather_enabled);

    req->concurreny_level = comm->mcast_group_count / 2;
    req->concurreny_level = ucc_min(req->concurreny_level, ONE_SIDED_MAX_CONCURRENT_LEVEL);
    req->concurreny_level = ucc_min(req->concurreny_level, comm->commsize);
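    /* illustrative (assumed numbers, not from the PR): with
     * mcast_group_count = 8, commsize = 16, and
     * ONE_SIDED_MAX_CONCURRENT_LEVEL >= 4, this resolves to
     * req->concurreny_level = 4 */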

    if (req->concurreny_level == 0) {
        tl_warn(comm->lib, "not enough concurreny level to enable zcopy pipeline allgather");
        return UCC_ERR_NOT_SUPPORTED;
    }

    if (comm->allgather_comm.mcast_prepost_bucket_size > req->num_packets) {
        req->mcast_prepost_bucket_size = req->num_packets;
    } else {
        req->mcast_prepost_bucket_size = comm->allgather_comm.mcast_prepost_bucket_size;
    }

    status = ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(comm, req);
    if (status != UCC_OK) {
        return status;
    }
Review comment on lines +340 to +341: indent issue.
    /* calculate the schedule and details of what we should
     * mcast and prepost to which mcast group at each stage */
    total_steps = req->num_packets * (comm->commsize / req->concurreny_level)
                  / req->mcast_prepost_bucket_size + 1;
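    /* worked example (assumed values, for illustration only): num_packets = 4,
     * commsize = 8, concurreny_level = 2, mcast_prepost_bucket_size = 2
     * => total_steps = 4 * (8 / 2) / 2 + 1 = 9 */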

    new_sched = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_pipelined_ag_schedule_t) * total_steps, "sched");
Review comment: nit: I would rename the variable …
    if (!new_sched) {
        tl_warn(comm->lib, "cannot allocate memory for schedule list");
        return UCC_ERR_NO_MEMORY;
    }

    /* generate schedule */
    for (i = 0; i < total_steps; i++) {
        ucc_assert(root < comm->commsize);
Review comment: is it useful? (just checking)
        if (i < total_steps - 1) {
            for (j = 0; j < req->concurreny_level; j++) {
                new_sched[i].prepost_buf_op[j].group_id = j + req->concurreny_level * (i % 2);
                new_sched[i].prepost_buf_op[j].offset   = offset * comm->max_per_packet;
                new_sched[i].prepost_buf_op[j].root     = root + j;
                new_sched[i].prepost_buf_op[j].count    = req->mcast_prepost_bucket_size;
            }
        } else {
            new_sched[i].prepost_buf_op_done = 1;
        }

        if (i > 0) {
            for (j = 0; j < req->concurreny_level; j++) {
                new_sched[i].multicast_op[j].group_id = new_sched[i - 1].prepost_buf_op[j].group_id;
Review comment: it seems that it way exceeds col 80. Have you run …? @Sergei-Lebedev @janjust do we want this formatting style to be compulsory? If yes we should have the CI check for it.
                new_sched[i].multicast_op[j].offset       = new_sched[i - 1].prepost_buf_op[j].offset;
                new_sched[i].multicast_op[j].offset_left  = new_sched[i - 1].prepost_buf_op[j].offset;
                new_sched[i].multicast_op[j].root         = new_sched[i - 1].prepost_buf_op[j].root;
                new_sched[i].multicast_op[j].to_send_left = new_sched[i - 1].prepost_buf_op[j].count;
                new_sched[i].multicast_op[j].to_recv      = new_sched[i - 1].prepost_buf_op[j].count;
                new_sched[i].to_recv += new_sched[i].multicast_op[j].to_recv;
                if (new_sched[i].multicast_op[j].root == comm->rank) {
                    new_sched[i].to_send += new_sched[i].multicast_op[j].to_send_left;
                }
            }
        }

        if (!new_sched[i].to_send || !new_sched[i].to_recv) {
            new_sched[i].multicast_op_done = 1;
        }

        offset += req->mcast_prepost_bucket_size;

        if (offset == req->num_packets) {
            offset = 0;
            root   = (root + req->concurreny_level) % comm->commsize;
        }
    }

    tl_trace(comm->lib,
             "generated the schedule for pipelined zero copy allgather with total_steps %d",
             total_steps);
    new_sched->total_steps = total_steps;
    req->total_steps       = total_steps;
    req->ag_schedule       = new_sched;
Review comment on lines +398 to +400: Is it redundant?

    tl_trace(comm->lib, "registering recv buf of size %ld", req->length * comm->commsize);
    ucc_assert(req->recv_rreg == NULL);

    status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->rptr, req->length *
                                            comm->commsize, &reg);
    if (UCC_OK != status) {
        tl_warn(comm->lib, "unable to register receive buffer %p of size %ld",
                req->rptr, req->length * comm->commsize);
        ucc_free(new_sched);
        return status;
    }

    req->recv_rreg = reg;
    req->recv_mr   = reg->mr;

    return UCC_OK;
}
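
The scheduling loop above is easier to follow on a toy configuration. The following self-contained sketch (assumed parameters; it mirrors the loop's logic but is not code from the PR) prints which mcast groups are preposted and which are multicast at each step:

```c
#include <stdio.h>

/* toy parameters, chosen only for illustration */
#define COMMSIZE 4 /* team size */
#define CONC     2 /* concurrency level: half of the mcast groups */
#define NUM_PKTS 4 /* packets per rank */
#define BUCKET   2 /* mcast_prepost_bucket_size */

int main(void)
{
    int total_steps = NUM_PKTS * (COMMSIZE / CONC) / BUCKET + 1; /* = 5 here */
    int root = 0, offset = 0, prev_root = 0, prev_offset = 0, i, j;

    for (i = 0; i < total_steps; i++) {
        printf("step %d:\n", i);
        if (i > 0) {
            /* multicast on the groups that were preposted in the previous step */
            for (j = 0; j < CONC; j++) {
                printf("  multicast group %d root %d pkt-offset %d count %d\n",
                       j + CONC * ((i - 1) % 2), prev_root + j, prev_offset, BUCKET);
            }
        }
        if (i < total_steps - 1) {
            /* even steps prepost the first half of the groups, odd steps the
             * second half */
            for (j = 0; j < CONC; j++) {
                printf("  prepost   group %d root %d pkt-offset %d count %d\n",
                       j + CONC * (i % 2), root + j, offset, BUCKET);
            }
        }
        prev_root   = root; /* remembered for the next step's multicast */
        prev_offset = offset;
        offset += BUCKET;
        if (offset == NUM_PKTS) {
            offset = 0;
            root   = (root + CONC) % COMMSIZE;
        }
    }
    return 0;
}
```

With these numbers, groups {0, 1} and {2, 3} alternate between the prepost and multicast roles from step to step, which is exactly the ping-pong described in the comment above the function.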

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
{
    ucc_coll_task_t *coll_task = &(task->super);

@@ -357,6 +503,13 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
    req->to_send = req->num_packets;
    req->to_recv = comm->commsize * req->num_packets;

    if (comm->allgather_comm.truly_zero_copy_allgather_enabled) {
        status = ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(comm, req);
        if (UCC_OK != status) {
            return status;
Review comment: memory leak of …
Review comment: also the memory needs to be deregistered
        }
    }

    comm->allgather_comm.coll_counter++;

    task->coll_mcast.req_handle = req;

@@ -104,6 +104,15 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
     UCC_CONFIG_TYPE_BOOL},

    {"MCAST_ZERO_COPY_ALLGATHER_ENABLE", "1", "Enable truly zero copy allgather design for mcast",
Review comment: what does "truly" mean in this context?
     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.truly_zero_copy_allgather_enabled),
     UCC_CONFIG_TYPE_BOOL},

    {"MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE", "16", "Number of posted recvs during each stage of the pipeline"
     " in truly zero copy mcast allgather design",
     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.mcast_prepost_bucket_size),
     UCC_CONFIG_TYPE_INT},

    {NULL}};
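
Assuming these entries follow the usual UCC lib-config convention, the two new fields should be settable at run time through environment variables, presumably `UCC_TL_MLX5_MCAST_ZERO_COPY_ALLGATHER_ENABLE` and `UCC_TL_MLX5_MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE` (the exact names depend on the TL's config prefix and are an assumption here, not confirmed by the PR).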

static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
Review comment: Typo: concurrency_level