Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: generate schedule for zcopy allgather #1059

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm_init_spec {
int max_eager;
int cuda_mem_enabled;
int one_sided_reliability_enable;
int truly_zero_copy_allgather_enabled;
int mcast_prepost_bucket_size;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

Expand Down Expand Up @@ -276,6 +278,8 @@ typedef struct ucc_tl_mlx5_mcast_allgather_comm {
uint32_t coll_counter;
uint32_t max_num_packets;
uint32_t max_push_send;
uint8_t truly_zero_copy_allgather_enabled;
uint32_t mcast_prepost_bucket_size;
} ucc_tl_mlx5_mcast_allgather_comm_t;

typedef struct ucc_tl_mlx5_mcast_bcast_comm {
Expand Down Expand Up @@ -434,6 +438,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
ucc_memory_type_t buf_mem_type;
enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme one_sided_reliability_scheme;
uint32_t ag_counter;
int concurreny_level;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: concurrency_level

int mcast_prepost_bucket_size;
int state;
ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *ag_schedule;
int total_steps;
Expand Down
153 changes: 153 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,152 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
}
}

/* Validate that the collective geometry permits the pipelined (zero copy)
 * allgather schedule.  Each constraint is checked on its own so the warning
 * pinpoints exactly which requirement failed, instead of one combined
 * condition with a generic message.
 *
 * @param comm  mcast communicator (supplies team size, packet size, rx depth)
 * @param req   collective request whose derived parameters are validated
 * @return UCC_OK when the pipelined schedule is applicable,
 *         UCC_ERR_NOT_SUPPORTED otherwise (caller falls back to another algo)
 */
static inline ucc_status_t
ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                                      ucc_tl_mlx5_mcast_coll_req_t *req)
{
    /* with an even group count the per-rank packet stream must split evenly
     * into prepost buckets */
    if (req->concurreny_level % 2 == 0 &&
        req->num_packets % req->mcast_prepost_bucket_size != 0) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "num_packets %d is not divisible by "
                "mcast_prepost_bucket_size %d while concurreny_level %d is even",
                req->num_packets, req->mcast_prepost_bucket_size,
                req->concurreny_level);
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* the root rotation walks the team in strides of concurreny_level */
    if (comm->commsize % req->concurreny_level != 0) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "team size %d is not divisible by concurreny_level %d",
                comm->commsize, req->concurreny_level);
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* every packet must be full-sized for the offset arithmetic to hold */
    if (req->length % comm->max_per_packet != 0) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "length %ld is not divisible by max_per_packet %d",
                req->length, comm->max_per_packet);
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* both pipeline halves must fit in the recv queue at once */
    if (req->mcast_prepost_bucket_size * req->concurreny_level * 2 >
        comm->params.rx_depth) {
        tl_warn(comm->lib, "Pipelined mcast allgather not supported: "
                "either reduce prepost_bucket_size or mcast group "
                "count or increase recv queue size "
                "mcast_prepost_bucket_size %d concurreny_level %d "
                "rx_depth %d",
                req->mcast_prepost_bucket_size, req->concurreny_level,
                comm->params.rx_depth);
        return UCC_ERR_NOT_SUPPORTED;
    }

    return UCC_OK;
}


/*
 * Build the pipelined ("zero copy") allgather schedule and register the
 * receive buffer.
 *
 * At each stage half of the mcast groups are ready for receiving mcast
 * packets while the other half are getting prepared by preposting recv
 * buffers.  Step i preposts the buffers that step i+1 multicasts into, so
 * payload lands directly in the user buffer with no staging copy.
 *
 * On success req->ag_schedule / req->total_steps / req->recv_rreg /
 * req->recv_mr are populated; ownership of the schedule allocation and the
 * memory registration passes to the caller's request-cleanup path.
 */
static inline ucc_status_t
ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                              ucc_tl_mlx5_mcast_coll_req_t *req)
{
    ucc_tl_mlx5_mcast_reg_t                   *reg      = NULL;
    ucc_rank_t                                 root     = 0;
    int                                        offset   = 0;
    ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *schedule = NULL;
    ucc_status_t                               status;
    ucc_rank_t                                 j, i;
    int                                        total_steps;

    ucc_assert(comm->allgather_comm.truly_zero_copy_allgather_enabled);

    /* half of the groups receive while the other half is being preposted,
     * bounded by the design limit and the team size */
    req->concurreny_level = comm->mcast_group_count / 2;
    req->concurreny_level = ucc_min(req->concurreny_level, ONE_SIDED_MAX_CONCURRENT_LEVEL);
    req->concurreny_level = ucc_min(req->concurreny_level, comm->commsize);

    if (req->concurreny_level == 0) {
        tl_warn(comm->lib, "not enough concurreny level to enable zcopy pipeline allgather");
        return UCC_ERR_NOT_SUPPORTED;
    }

    /* a bucket can never exceed the number of packets each rank contributes */
    req->mcast_prepost_bucket_size =
        ucc_min(req->num_packets, comm->allgather_comm.mcast_prepost_bucket_size);

    status = ucc_tl_mlx5_mcast_validate_zero_copy_allgather_params(comm, req);
    if (status != UCC_OK) {
        return status;
    }

    /* calculate the schedule and details of what we should
     * mcast and prepost to which mcast group at each stage; the trailing +1
     * accounts for the drain step that only multicasts, never preposts */
    total_steps = req->num_packets * (comm->commsize / req->concurreny_level)
                  / req->mcast_prepost_bucket_size + 1;

    schedule = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_pipelined_ag_schedule_t)
                          * total_steps, "sched");
    if (!schedule) {
        tl_warn(comm->lib, "cannot allocate memory for schedule list");
        return UCC_ERR_NO_MEMORY;
    }

    /* generate schedule */
    for (i = 0; i < total_steps; i++) {
        ucc_assert(root < comm->commsize);
        /* every step but the last preposts buckets for the next step */
        if (i < total_steps - 1) {
            for (j = 0; j < req->concurreny_level; j++) {
                /* alternate between the two halves of the group set */
                schedule[i].prepost_buf_op[j].group_id = j + req->concurreny_level * (i % 2);
                schedule[i].prepost_buf_op[j].offset   = offset * comm->max_per_packet;
                schedule[i].prepost_buf_op[j].root     = root + j;
                schedule[i].prepost_buf_op[j].count    = req->mcast_prepost_bucket_size;
            }
        } else {
            schedule[i].prepost_buf_op_done = 1;
        }

        /* every step but the first multicasts into what the previous step
         * preposted */
        if (i > 0) {
            for (j = 0; j < req->concurreny_level; j++) {
                schedule[i].multicast_op[j].group_id     = schedule[i - 1].prepost_buf_op[j].group_id;
                schedule[i].multicast_op[j].offset       = schedule[i - 1].prepost_buf_op[j].offset;
                schedule[i].multicast_op[j].offset_left  = schedule[i - 1].prepost_buf_op[j].offset;
                schedule[i].multicast_op[j].root         = schedule[i - 1].prepost_buf_op[j].root;
                schedule[i].multicast_op[j].to_send_left = schedule[i - 1].prepost_buf_op[j].count;
                schedule[i].multicast_op[j].to_recv      = schedule[i - 1].prepost_buf_op[j].count;
                schedule[i].to_recv                     += schedule[i].multicast_op[j].to_recv;
                if (schedule[i].multicast_op[j].root == comm->rank) {
                    schedule[i].to_send += schedule[i].multicast_op[j].to_send_left;
                }
            }
        }

        if (!schedule[i].to_send || !schedule[i].to_recv) {
            schedule[i].multicast_op_done = 1;
        }

        offset += req->mcast_prepost_bucket_size;

        /* finished this set of roots' packets: advance to the next stride */
        if (offset == req->num_packets) {
            offset = 0;
            root   = (root + req->concurreny_level) % comm->commsize;
        }
    }

    tl_trace(comm->lib,
             "generated the schedule for pipelined zero copy allgather with total_steps %d",
             total_steps);
    /* NOTE(review): total_steps is stored both on the first schedule entry
     * and on the request; looks redundant — confirm which one consumers use */
    schedule->total_steps = total_steps;
    req->total_steps      = total_steps;
    req->ag_schedule      = schedule;
    tl_trace(comm->lib, "registering recv buf of size %ld", req->length * comm->commsize);
    ucc_assert(req->recv_rreg == NULL);

    status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->rptr, req->length *
                                            comm->commsize, &reg);
    if (UCC_OK != status) {
        tl_warn(comm->lib, "unable to register receive buffer %p of size %ld",
                req->rptr, req->length * comm->commsize);
        /* don't leave a dangling schedule pointer behind on failure */
        ucc_free(schedule);
        req->ag_schedule = NULL;
        return status;
    }

    req->recv_rreg = reg;
    req->recv_mr   = reg->mr;

    return UCC_OK;
}

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
{
ucc_coll_task_t *coll_task = &(task->super);
Expand Down Expand Up @@ -357,6 +503,13 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
req->to_send = req->num_packets;
req->to_recv = comm->commsize * req->num_packets;

if (comm->allgather_comm.truly_zero_copy_allgather_enabled) {
status = ucc_tl_mlx5_mcast_prepare_zero_copy_allgather(comm, req);
if (UCC_OK != status) {
return status;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

memory leak of req

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also the memory needs to be deregistered

}
}

comm->allgather_comm.coll_counter++;

task->coll_mcast.req_handle = req;
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,

memcpy(&comm->params, conf_params, sizeof(*conf_params));

comm->allgather_comm.mcast_prepost_bucket_size
= conf_params->mcast_prepost_bucket_size;
comm->allgather_comm.truly_zero_copy_allgather_enabled
= conf_params->truly_zero_copy_allgather_enabled;
comm->one_sided.reliability_enabled = conf_params->one_sided_reliability_enable;
comm->bcast_comm.wsize = conf_params->wsize;
comm->allgather_comm.max_push_send = conf_params->max_push_send;
Expand Down
9 changes: 9 additions & 0 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
UCC_CONFIG_TYPE_BOOL},

{"MCAST_ZERO_COPY_ALLGATHER_ENABLE", "1", "Enable truly zero copy allgather design for mcast",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does "truly" mean in this context?

ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.truly_zero_copy_allgather_enabled),
UCC_CONFIG_TYPE_BOOL},

{"MCAST_ZERO_COPY_PREPOST_BUCKET_SIZE", "16", "Number of posted recvs during each stage of the pipeline"
" in truly zero copy mcast allgather design",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.mcast_prepost_bucket_size),
UCC_CONFIG_TYPE_INT},

{NULL}};

static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
Expand Down
Loading