diff --git a/src/components/tl/ucp/alltoall/alltoall_ca.c b/src/components/tl/ucp/alltoall/alltoall_ca.c index a9c3bc5a5c..9511da3101 100644 --- a/src/components/tl/ucp/alltoall/alltoall_ca.c +++ b/src/components/tl/ucp/alltoall/alltoall_ca.c @@ -11,12 +11,16 @@ #include "utils/ucc_math.h" #include "tl_ucp_sendrecv.h" +/* RTT threshold: a peer is served only when its pinger_query() result is <= this value. Placeholder; update when pinger rtt complete */ +#define MAGIC_NUMBER 1.0 + void ucc_tl_ucp_alltoall_onesided_ca_progress(ucc_coll_task_t *ctask); ucc_status_t ucc_tl_ucp_alltoall_onesided_ca_start(ucc_coll_task_t *ctask) { ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t); ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); /* supplies ctx->pinger and ctx->pinger_peer[] for the RTT queries below */ ptrdiff_t src = (ptrdiff_t)TASK_ARGS(task).src.info.buffer; ptrdiff_t dest = (ptrdiff_t)TASK_ARGS(task).dst.info.buffer; size_t nelems = TASK_ARGS(task).src.info.count; @@ -24,6 +28,10 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_ca_start(ucc_coll_task_t *ctask) ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); ucc_rank_t start = (grank + 1) % gsize; long * pSync = TASK_ARGS(task).global_work_buffer; + int revisit[128] = {0}; /* peers deferred because rtt > MAGIC_NUMBER. NOTE(review): fixed capacity of 128 and no bounds check on nr_revisit is visible here - confirm against gsize */ + int nr_revisit = 0; /* number of deferred peers currently recorded in revisit[] */ + int nr_revisit_max = 0; /* snapshot of nr_revisit taken at the start of each retry pass */ + pinger_rtt_t rtt; ucc_rank_t peer; ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); @@ -33,11 +41,34 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_ca_start(ucc_coll_task_t *ctask) /* maybe have a list of processes to send to, cut them out of the process list */ for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) { - UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems), - (void *)dest, nelems, peer, team, task), - task, out); - UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, team), task, - out); + pinger_query(ctx->pinger, ctx->pinger_peer[peer], &rtt); /* query the RTT for this peer; the put and the pSync increment are posted only when it is within MAGIC_NUMBER, otherwise the peer is deferred. NOTE(review): this loop ends only once put_posted reaches gsize, so a deferred peer presumably lets it wrap onto peers that were already served - verify how put_posted is updated */ + if (rtt <= MAGIC_NUMBER) { + UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems), + (void *)dest, nelems, peer, team, task), + task, out); + UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, team), task, + out); + } 
else { + revisit[nr_revisit++] = peer; /* remember the peer for the retry passes below */ + } + } + + while (nr_revisit) { /* retry passes: re-query every deferred peer and keep the ones still above the threshold. NOTE(review): no retry limit or progress call is visible in this loop - confirm it cannot spin indefinitely */ + nr_revisit_max = nr_revisit; + nr_revisit = 0; + for (int i = 0; i < nr_revisit_max; i++) { + peer = revisit[i]; + pinger_query(ctx->pinger, ctx->pinger_peer[peer], &rtt); + if (rtt <= MAGIC_NUMBER) { + UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems), + (void *)dest, nelems, peer, team, task), + task, out); + UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, team), task, + out); + } else { + revisit[nr_revisit++] = peer; /* compacts in place: nr_revisit never exceeds i at this point, so entries not yet read are not overwritten */ + } + } } return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);