From 434f3dce0f70b0430f610418856c8b4e0408e048 Mon Sep 17 00:00:00 2001 From: Tomislavj Janjusic Date: Wed, 6 Jan 2021 14:52:27 -0800 Subject: [PATCH 1/2] Adding the XCCL DPU team, and DPU daemon Signed-off-by: Tomislavj Janjusic Co-authored-by: Artem Polyakov Co-authored-by: Sergey Lebedev --- configure.ac | 8 + contrib/dpu_daemon/Makefile | 20 + contrib/dpu_daemon/dpu_server.c | 171 +++++++ contrib/dpu_daemon/host_channel.c | 647 +++++++++++++++++++++++++ contrib/dpu_daemon/host_channel.h | 97 ++++ contrib/dpu_daemon/server_xccl.c | 206 ++++++++ contrib/dpu_daemon/server_xccl.h | 43 ++ src/Makefile.am | 3 + src/api/xccl_tls.h | 5 +- src/team_lib/dpu/Makefile.am | 24 + src/team_lib/dpu/xccl_dpu_lib.c | 662 ++++++++++++++++++++++++++ src/team_lib/dpu/xccl_dpu_lib.h | 82 ++++ src/team_lib/hier/xccl_hier_context.c | 3 +- 13 files changed, 1969 insertions(+), 2 deletions(-) create mode 100644 contrib/dpu_daemon/Makefile create mode 100644 contrib/dpu_daemon/dpu_server.c create mode 100644 contrib/dpu_daemon/host_channel.c create mode 100644 contrib/dpu_daemon/host_channel.h create mode 100644 contrib/dpu_daemon/server_xccl.c create mode 100644 contrib/dpu_daemon/server_xccl.h create mode 100644 src/team_lib/dpu/Makefile.am create mode 100644 src/team_lib/dpu/xccl_dpu_lib.c create mode 100644 src/team_lib/dpu/xccl_dpu_lib.h diff --git a/configure.ac b/configure.ac index ed1190b..bfcf1d2 100644 --- a/configure.ac +++ b/configure.ac @@ -117,6 +117,13 @@ AC_ARG_WITH([cuda], AM_CONDITIONAL([HAVE_CUDA], [test "x$cuda_happy" != xno]) AC_MSG_RESULT([CUDA support: $cuda_happy; $CUDA_CPPFLAGS $CUDA_LDFLAGS]) +AC_ARG_WITH([dpu], + AC_HELP_STRING([--with-dpu=yes/no], [Enable/Disable dpu team]), + [AS_IF([test "x$with_dpu" != "xno"], dpu_happy="yes", dpu_happy="no")], + [dpu_happy="no"]) +AM_CONDITIONAL([HAVE_DPU], [test "x$dpu_happy" != xno]) +AC_MSG_RESULT([DPU support: $dpu_happy]) + AM_CONDITIONAL([HAVE_NCCL], [false]) if test "x$cuda_happy" != xno; then m4_include([m4/nccl.m4]) @@ -136,6 +143,7 @@ AC_CONFIG_FILES([ src/team_lib/hier/Makefile src/team_lib/multirail/Makefile src/team_lib/nccl/Makefile + src/team_lib/dpu/Makefile src/utils/cuda/Makefile src/utils/cuda/kernels/Makefile test/Makefile diff --git a/contrib/dpu_daemon/Makefile b/contrib/dpu_daemon/Makefile new file mode 100644 index 0000000..975304f --- /dev/null +++ b/contrib/dpu_daemon/Makefile @@ -0,0 +1,20 @@ +# +# Copyright (c) 2020 Mellanox Technologies. All rights reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +CFLAGS = -I$(XCCL_PATH)/include -I$(UCX_PATH)/include +LDFLAGS = -L$(XCCL_PATH)/lib $(XCCL_PATH)/lib/libxccl.so $(UCX_PATH)/lib/libucs.so $(UCX_PATH)/lib/libucp.so -Wl,-rpath -Wl,$(XCCL_PATH)/lib -Wl,-rpath -Wl,$(XCCL_PATH)/lib + +rel: + mpicc -O3 -DNDEBUG -std=c11 $(CFLAGS) -o dpu_server dpu_server.c host_channel.c server_xccl.c $(LDFLAGS) + +dbg: + mpicc -O0 -g -std=c11 $(CFLAGS) -o dpu_server dpu_server.c host_channel.c server_xccl.c $(LDFLAGS) + +clean: + rm -f dpu_server diff --git a/contrib/dpu_daemon/dpu_server.c b/contrib/dpu_daemon/dpu_server.c new file mode 100644 index 0000000..ac6fce1 --- /dev/null +++ b/contrib/dpu_daemon/dpu_server.c @@ -0,0 +1,171 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#include +#include +#include +#include + +#include "server_xccl.h" +#include "host_channel.h" + +#define MAX_THREADS 128 +typedef struct { + pthread_t id; + int idx, nthreads; + dpu_xccl_comm_t comm; + dpu_hc_t *hc; + unsigned int itt; +} thread_ctx_t; + +/* thread accisble data - split reader/writer */ +typedef struct { + volatile unsigned long g_itt; /* first cache line */ + volatile unsigned long pad[3]; /* pad to 64bytes */ + volatile unsigned long l_itt; /* second cache line */ + volatile unsigned long pad2[3]; /* pad to 64 bytes */ +} thread_sync_t; + +static thread_sync_t *thread_sync = NULL; + +void *dpu_worker(void *arg) +{ + int i = 0; + thread_ctx_t *ctx = (thread_ctx_t*)arg; + xccl_coll_req_h request; + + while(1) { + ctx->itt++; + if (ctx->idx > 0) { + while (thread_sync[ctx->idx].g_itt < ctx->itt) { + /* busy wait */ + } + } + else { + dpu_hc_wait(ctx->hc, ctx->itt); + for (i = 0; i < ctx->nthreads; i++) { + thread_sync[i].g_itt++; + } + } + + int offset, block; + int count = dpu_hc_get_count(ctx->hc); + int ready = 0; + + block = count / ctx->nthreads; + offset = block * ctx->idx; + if(ctx->idx < (count % ctx->nthreads)) { + offset += ctx->idx; + block++; + } else { + offset += (count % ctx->nthreads); + } + + xccl_coll_op_args_t coll = { + .field_mask = 0, + .coll_type = XCCL_ALLREDUCE, + .buffer_info = { + .src_buffer = ctx->hc->mem_segs.put.base + offset * sizeof(int), + .dst_buffer = ctx->hc->mem_segs.get.base + offset * sizeof(int), + .len = block * xccl_dt_size(dpu_hc_get_dtype(ctx->hc)), + }, + .reduce_info = { + .dt = dpu_hc_get_dtype(ctx->hc), + .op = dpu_hc_get_op(ctx->hc), + .count = block, + }, + .alg.set_by_user = 0, + .tag = 123, //todo + }; + + if (coll.reduce_info.op == XCCL_OP_UNSUPPORTED) { + break; + } + + XCCL_CHECK(xccl_collective_init(&coll, &request, ctx->comm.team)); + XCCL_CHECK(xccl_collective_post(request)); + while (XCCL_OK != xccl_collective_test(request)) { + xccl_context_progress(ctx->comm.ctx); + } + XCCL_CHECK(xccl_collective_finalize(request)); + + thread_sync[ctx->idx].l_itt++; + + if (ctx->idx == 0) { + while (ready != ctx->nthreads) { + ready = 0; + for (i = 0; i < ctx->nthreads; i++) { + if (thread_sync[i].l_itt == ctx->itt) { + ready++; + } + else { + break; + } + } + } + + dpu_hc_reply(ctx->hc, ctx->itt); + } + } + + return NULL; +} + +int main(int argc, char **argv) +{ + int nthreads = 0, i; + thread_ctx_t *tctx_pool = NULL; + dpu_xccl_global_t xccl_glob; + dpu_hc_t hc_b, *hc = &hc_b; + + if (argc < 2 ) { + printf("Need thread # as an argument\n"); + return 1; + } + nthreads = atoi(argv[1]); + if (MAX_THREADS < nthreads || 0 >= nthreads) { + printf("ERROR: bad thread #: %d\n", nthreads); + return 1; + } + printf("DPU daemon: Running with %d threads\n", nthreads); + tctx_pool = calloc(nthreads, sizeof(*tctx_pool)); + XCCL_CHECK(dpu_xccl_init(argc, argv, &xccl_glob)); + +// thread_sync = calloc(nthreads, sizeof(*thread_sync)); + thread_sync = aligned_alloc(64, nthreads * sizeof(*thread_sync)); + memset(thread_sync, 0, nthreads * sizeof(*thread_sync)); + + dpu_hc_init(hc); + dpu_hc_accept(hc); + + for(i = 0; i < nthreads; i++) { +// printf("Thread %d spawned!\n", i); + XCCL_CHECK(dpu_xccl_alloc_team(&xccl_glob, &tctx_pool[i].comm)); + + tctx_pool[i].idx = i; + tctx_pool[i].nthreads = nthreads; + tctx_pool[i].hc = hc; + tctx_pool[i].itt = 0; + + if (i < nthreads - 1) { + pthread_create(&tctx_pool[i].id, NULL, dpu_worker, + (void*)&tctx_pool[i]); + } + } + + /* The final DPU worker is executed in this context */ + dpu_worker((void*)&tctx_pool[i-1]); + + for(i = 0; i < nthreads; i++) { + if (i < nthreads - 1) { + pthread_join(tctx_pool[i].id, NULL); + } + dpu_xccl_free_team(&xccl_glob, &tctx_pool[i].comm); +// printf("Thread %d joined!\n", i); + } + + dpu_xccl_finalize(&xccl_glob); + return 0; +} diff --git a/contrib/dpu_daemon/host_channel.c b/contrib/dpu_daemon/host_channel.c new file mode 100644 index 0000000..0e28e75 --- /dev/null +++ b/contrib/dpu_daemon/host_channel.c @@ -0,0 +1,647 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#include "host_channel.h" +#include +#include + +static int _dpu_host_to_ip(dpu_hc_t *hc) +{ +// printf ("%s\n", __FUNCTION__); + struct hostent *he; + struct in_addr **addr_list; + int i; + + hc->hname = calloc(1, 100 * sizeof(char)); + hc->ip = malloc(100 * sizeof(char)); + + int ret = gethostname(hc->hname, 100); + if (ret) { + return 1; + } + + if ( (he = gethostbyname( hc->hname ) ) == NULL) + { + // get the host info + herror("gethostbyname"); + return 1; + } + + addr_list = (struct in_addr **) he->h_addr_list; + for(i = 0; addr_list[i] != NULL; i++) + { + //Return the first one; + strcpy(hc->ip , inet_ntoa(*addr_list[i]) ); + return XCCL_OK; + } + return XCCL_ERR_NO_MESSAGE; +} + +static int _dpu_listen(dpu_hc_t *hc) +{ + struct sockaddr_in serv_addr; + + if(_dpu_host_to_ip(hc)) { + return XCCL_ERR_NO_MESSAGE; + } + + hc->port = DEFAULT_PORT; + /* TODO: if envar(port) - replace */ + + /* creates an UN-named socket inside the kernel and returns + * an integer known as socket descriptor + * This function takes domain/family as its first argument. + * For Internet family of IPv4 addresses we use AF_INET + */ + hc->listenfd = socket(AF_INET, SOCK_STREAM, 0); + if (0 > hc->listenfd) { + fprintf(stderr, "socket() failed (%s)\n", strerror(errno)); + goto err_ip; + } + memset(&serv_addr, 0, sizeof(serv_addr)); + + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + serv_addr.sin_port = htons(hc->port); + + /* The call to the function "bind()" assigns the details specified + * in the structure 『serv_addr' to the socket created in the step above + */ + if (0 > bind(hc->listenfd, (struct sockaddr*)&serv_addr, + sizeof(serv_addr))) { + fprintf(stderr, "Failed to bind() (%s)\n", strerror(errno)); + goto err_sock; + } + + /* The call to the function "listen()" with second argument as 10 specifies + * maximum number of client connections that server will queue for this listening + * socket. + */ + if (0 > listen(hc->listenfd, 10)) { + fprintf(stderr, "listen() failed (%s)\n", strerror(errno)); + goto err_sock; + } + + return XCCL_OK; +err_sock: + close(hc->listenfd); +err_ip: + free(hc->ip); + free(hc->hname); + return XCCL_ERR_NO_MESSAGE; +} + +static int _dpu_listen_cleanup(dpu_hc_t *hc) +{ + close(hc->listenfd); + free(hc->ip); + free(hc->hname); +} + +static void tag_recv_cb (void *request, ucs_status_t status, + const ucp_tag_recv_info_t *info, void *user_data) +{ + dpu_req_t *ctx = user_data; + ctx->complete = 1; +} + +static void send_cb(void *request, ucs_status_t status, void *user_data) +{ + dpu_req_t *ctx = user_data; + ctx->complete = 1; +} + +static void err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) +{ + printf ("error handling callback was invoked with status %d (%s)\n", + status, ucs_status_string(status)); +} + +static int _dpu_ucx_init(dpu_hc_t *hc) +{ + ucp_params_t ucp_params; + ucs_status_t status; + ucp_worker_params_t worker_params; + int ret = SUCCESS; + +// printf ("%s\n", __FUNCTION__); + + memset(&ucp_params, 0, sizeof(ucp_params)); + ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES; + ucp_params.features = UCP_FEATURE_TAG | + UCP_FEATURE_RMA; +/* + UCP_FEATURE_AMO64 | + UCP_FEATURE_AMO32; +*/ + status = ucp_init(&ucp_params, NULL, &hc->ucp_ctx); + if (status != UCS_OK) { + fprintf(stderr, "failed to ucp_init(%s)\n", ucs_status_string(status)); + ret = XCCL_ERR_NO_MESSAGE; + goto err; + } + + memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; + + status = ucp_worker_create(hc->ucp_ctx, &worker_params, &hc->ucp_worker); + if (status != UCS_OK) { + fprintf(stderr, "failed to ucp_worker_create (%s)\n", ucs_status_string(status)); + ret = XCCL_ERR_NO_MESSAGE; + goto err_cleanup; + } + + hc->worker_attr.field_mask = UCP_WORKER_ATTR_FIELD_ADDRESS | + UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS; + hc->worker_attr.address_flags = UCP_WORKER_ADDRESS_FLAG_NET_ONLY; + status = ucp_worker_query (hc->ucp_worker, &hc->worker_attr); + if (UCS_OK != status) { + ret = XCCL_ERR_NO_MESSAGE; + goto err_worker; + } + + return ret; +err_worker: + ucp_worker_destroy(hc->ucp_worker); +err_cleanup: + ucp_cleanup(hc->ucp_ctx); +err: + return ret; +} + +static int _dpu_ucx_fini(dpu_hc_t *hc){ + ucp_worker_release_address(hc->ucp_worker, hc->worker_attr.address); + ucp_worker_destroy(hc->ucp_worker); + ucp_cleanup(hc->ucp_ctx); +} + + +static int _dpu_hc_buffer_alloc(dpu_hc_t *hc, dpu_mem_t *mem, size_t size) +{ + ucp_mem_map_params_t mem_params; + ucp_mem_attr_t mem_attr; + ucs_status_t status; + int ret = XCCL_OK; + + memset(mem, 0, sizeof(*mem)); + mem->base = calloc(size, sizeof(char)); + memset(&mem_params, 0, sizeof(ucp_mem_map_params_t)); + + mem_params.address = mem->base; + mem_params.length = size; + + mem_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_FLAGS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_ADDRESS; + + status = ucp_mem_map(hc->ucp_ctx, &mem_params, &mem->memh); + if (status != UCS_OK) { + fprintf(stderr, "failed to ucp_mem_map (%s)\n", ucs_status_string(status)); + ret = XCCL_ERR_NO_MESSAGE; + goto out; + } + + mem_attr.field_mask = UCP_MEM_ATTR_FIELD_ADDRESS | + UCP_MEM_ATTR_FIELD_LENGTH; + + status = ucp_mem_query(mem->memh, &mem_attr); + if (status != UCS_OK) { + fprintf(stderr, "failed to ucp_mem_query (%s)\n", ucs_status_string(status)); + ret = XCCL_ERR_NO_MESSAGE; + goto err_map; + } + assert(mem_attr.length == size); + assert(mem_attr.address == mem->base); + + status = ucp_rkey_pack(hc->ucp_ctx, mem->memh, + &mem->rkey.rkey_addr, + &mem->rkey.rkey_addr_len); + if (status != UCS_OK) { + fprintf(stderr, "failed to ucp_rkey_pack (%s)\n", ucs_status_string(status)); + ret = XCCL_ERR_NO_MESSAGE; + goto err_map; + } + + goto out; +err_map: + ucp_mem_unmap(hc->ucp_ctx, mem->memh); +err_calloc: + free(mem->base); +out: + return ret; +} + +static int _dpu_hc_buffer_free(dpu_hc_t *hc, dpu_mem_t *mem) +{ + ucp_rkey_buffer_release(mem->rkey.rkey_addr); + ucp_mem_unmap(hc->ucp_ctx, mem->memh); + free(mem->base); +} + + +static size_t _dpu_set_buffer_size(char *_env) +{ + char *env = getenv(_env); + return env != NULL ? atol(env) : DATA_BUFFER_SIZE; +} + +int dpu_hc_init(dpu_hc_t *hc) +{ + int ret = XCCL_OK; + + memset(hc, 0, sizeof(*hc)); + + /* Start listening */ + ret = _dpu_listen(hc); + if (ret) { + goto out; + } + + /* init ucx objects */ + ret = _dpu_ucx_init(hc); + if (ret) { + goto err_ip; + } + + /* set buffer size */ + hc->data_buffer_size = _dpu_set_buffer_size("DPU_DATA_BUFFER_SIZE"); + + ret = _dpu_hc_buffer_alloc(hc, &hc->mem_segs.put, DATA_BUFFER_SIZE); + if (ret) { + goto err_ucx; + } + ret = _dpu_hc_buffer_alloc(hc, &hc->mem_segs.get, DATA_BUFFER_SIZE); + if (ret) { + goto err_put; + } + ret = _dpu_hc_buffer_alloc(hc, &hc->mem_segs.sync, sizeof(dpu_sync_t)); + if (ret) { + goto err_get; + } + + goto out; +err_get: + _dpu_hc_buffer_free(hc, &hc->mem_segs.get); +err_put: + _dpu_hc_buffer_free(hc, &hc->mem_segs.put); +err_ucx: + _dpu_ucx_fini(hc); +err_ip: + _dpu_listen_cleanup(hc); +out: + return ret; +} + +static ucs_status_t _dpu_ep_create (dpu_hc_t *hc, void *rem_worker_addr) +{ + ucs_status_t status; + ucp_ep_params_t ep_params; + + ep_params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | + UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | + UCP_EP_PARAM_FIELD_ERR_HANDLER | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; + ep_params.err_mode = UCP_ERR_HANDLING_MODE_PEER; + ep_params.err_handler.cb = err_cb; + ep_params.address = rem_worker_addr; + + status = ucp_ep_create(hc->ucp_worker, &ep_params, &hc->host_ep); + if (status != UCS_OK) { + fprintf(stderr, "failed to create an endpoint on the dpu (%s)\n", + ucs_status_string(status)); + return XCCL_ERR_NO_MESSAGE; + } + + return XCCL_OK; +} + +static int _dpu_ep_close(dpu_hc_t *hc) +{ + ucp_request_param_t param; + ucs_status_t status; + void *close_req; + int ret = XCCL_OK; + + param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS; + param.flags = UCP_EP_CLOSE_FLAG_FORCE; + close_req = ucp_ep_close_nbx(hc->host_ep, ¶m); + if (UCS_PTR_IS_PTR(close_req)) { + do { + ucp_worker_progress(hc->ucp_worker); + status = ucp_request_check_status(close_req); + } while (status == UCS_INPROGRESS); + + ucp_request_free(close_req); + } else if (UCS_PTR_STATUS(close_req) != UCS_OK) { + fprintf(stderr, "failed to close ep %p\n", (void *)hc->host_ep); + ret = XCCL_ERR_NO_MESSAGE; + } + return ret; +} + + +static ucs_status_t _dpu_request_wait(ucp_worker_h ucp_worker, void *request, + dpu_req_t *req_ctx) +{ +// printf ("%s\n", __FUNCTION__); + ucs_status_t status; + + /* immediate completion */ + if (request == NULL) { + return UCS_OK; + } + + if (UCS_PTR_IS_ERR(request)) { + return UCS_PTR_STATUS(request); + } + + while (req_ctx->complete == 0) { +// printf ("ucp_worker_progress()\n"); +// sleep(1); + ucp_worker_progress(ucp_worker); + } + status = ucp_request_check_status(request); + + ucp_request_free(request); + + return status; +} + +static int _dpu_request_finalize (ucp_worker_h ucp_worker, dpu_req_t *request, + dpu_req_t *req_ctx) +{ +// printf ("%s\n", __FUNCTION__); + ucs_status_t status; + int ret = SUCCESS; + + status = _dpu_request_wait(ucp_worker, request, req_ctx); + if (status != UCS_OK) { + fprintf (stderr, "unable to recv UCX message (%s)\n", ucs_status_string(status)); + return -1; + } + + /* reset req_ctx */ + req_ctx->complete = 0; + + return ret; +} + +static int _dpu_rmem_setup(dpu_hc_t *hc) +{ +// printf ("%s\n", __FUNCTION__); + + int i, ret = SUCCESS; + ucp_request_param_t param; + void *request; + dpu_req_t req_ctx; + size_t rkeys_total_len = 0, rkey_lens[3]; + uint64_t seg_base_addrs[3]; + char *rkeys = NULL, *rkey_p; + + req_ctx.complete = 0; + + /* XXX */ +// ucp_worker_print_info(hc->ucp_worker, stderr); + + /* recv rkey len & address */ /* recv rkey len & address */ + param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA; + param.user_data = &req_ctx; + param.cb.recv = tag_recv_cb; + param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA; + param.user_data = &req_ctx; + param.cb.recv = tag_recv_cb; + request = ucp_tag_recv_nbx(hc->ucp_worker, &hc->sync_addr, sizeof(uint64_t), + EXCHANGE_ADDR_TAG, (uint64_t)-1, ¶m); + ret = _dpu_request_finalize(hc->ucp_worker, request, &req_ctx); + if (ret) { + goto err; + } + + request = ucp_tag_recv_nbx(hc->ucp_worker, &rkey_lens[0], sizeof(size_t), + EXCHANGE_LENGTH_TAG, (uint64_t)-1, ¶m); + ret = _dpu_request_finalize(hc->ucp_worker, request, &req_ctx); + if (ret) { + goto err; + } + + rkeys = calloc(1, rkey_lens[0]); + request = ucp_tag_recv_nbx(hc->ucp_worker, rkeys, rkey_lens[0], EXCHANGE_RKEY_TAG, + (uint64_t)-1, ¶m); + + ret = _dpu_request_finalize(hc->ucp_worker, request, &req_ctx); + if (ret) { + goto err; + } + + ucs_status_t status = ucp_ep_rkey_unpack(hc->host_ep, rkeys, &hc->sync_rkey); + if (UCS_OK != status) { + fprintf(stderr, "failed to ucp_ep_rkey_unpack (%s)\n", ucs_status_string(status)); + } + free(rkeys); + + /* send rkey lens & addresses */ + param.cb.send = send_cb; + + /* compute total len */ + for (i = 0; i < 3; i++) { + rkey_lens[i] = hc->mem_segs_array[i].rkey.rkey_addr_len; + seg_base_addrs[i] = (uint64_t)hc->mem_segs_array[i].base; + rkeys_total_len += rkey_lens[i]; +// fprintf (stdout, "rkey_total_len = %lu, rkey_lens[i] = %lu\n", +// rkeys_total_len, rkey_lens[i]); + } + + rkey_p = rkeys = calloc(1, rkeys_total_len); + + /* send rkey_lens */ + request = ucp_tag_send_nbx(hc->host_ep, rkey_lens, 3*sizeof(size_t), + EXCHANGE_LENGTH_TAG, ¶m); + ret = _dpu_request_finalize(hc->ucp_worker, request, &req_ctx); + if (ret) { + goto err; + } + + request = ucp_tag_send_nbx(hc->host_ep, seg_base_addrs, 3*sizeof(uint64_t), + EXCHANGE_ADDR_TAG, ¶m); + ret = _dpu_request_finalize(hc->ucp_worker, request, &req_ctx); + if (ret) { + goto err; + } + + /* send rkeys */ + for (i = 0; i < 3; i++) { + memcpy(rkey_p, hc->mem_segs_array[i].rkey.rkey_addr, rkey_lens[i]); + rkey_p+=rkey_lens[i]; + } + + request = ucp_tag_send_nbx(hc->host_ep, rkeys, rkeys_total_len, EXCHANGE_RKEY_TAG, ¶m); + ret = _dpu_request_finalize(hc->ucp_worker, request, &req_ctx); + if (ret) { + goto err; + } + + return SUCCESS; + +err: + printf ("%s ERROR!\n", __FUNCTION__); + return ret; +} + + +int dpu_hc_accept(dpu_hc_t *hc) +{ + int ret; + ucs_status_t status; + ucp_rkey_h client_rkey_h; + void *rem_worker_addr; + size_t rem_worker_addr_len; + + /* In the call to accept(), the server is put to sleep and when for an incoming + * client request, the three way TCP handshake* is complete, the function accept() + * wakes up and returns the socket descriptor representing the client socket. + */ +// fprintf (stderr, "Waiting for connection...\n"); + hc->connfd = accept(hc->listenfd, (struct sockaddr*)NULL, NULL); + if (-1 == hc->connfd) { + fprintf(stderr, "Error in accept (%s)!\n", strerror(errno)); + } +// fprintf (stderr, "Connection established\n"); + + ret = send(hc->connfd, &hc->worker_attr.address_length, sizeof(size_t), 0); + if (-1 == ret) { + fprintf(stderr, "send worker_address_length failed!\n"); + ret = XCCL_ERR_NO_MESSAGE; + goto err; + } + + ret = send(hc->connfd, hc->worker_attr.address, + hc->worker_attr.address_length, 0); + if (-1 == ret) { + fprintf(stderr, "send worker_address failed!\n"); + fprintf(stderr, "mmap_buffer failed!\n"); + ret = XCCL_ERR_NO_MESSAGE; + goto err; + } + + ret = recv(hc->connfd, &rem_worker_addr_len, sizeof(size_t), MSG_WAITALL); + if (-1 == ret) { + fprintf(stderr, "recv address_length failed!\n"); + ret = XCCL_ERR_NO_MESSAGE; + goto err; + } + + rem_worker_addr = calloc(1, rem_worker_addr_len); + + ret = recv(hc->connfd, rem_worker_addr, rem_worker_addr_len, MSG_WAITALL); + if (-1 == ret) { + fprintf(stderr, "recv worker address failed!\n"); + ret = XCCL_ERR_NO_MESSAGE; + goto err; + } + + if (ret = _dpu_ep_create(hc, rem_worker_addr)) { + fprintf(stderr, "dpu_create_ep failed!\n"); + ret = XCCL_ERR_NO_MESSAGE; + goto err; + } + + ret = _dpu_rmem_setup(hc); + if (ret) { + fprintf(stderr, "exchange data failed!\n"); + goto err; + } + + return ret; + +err: + close(hc->connfd); + return ret; +} + +int dpu_hc_wait(dpu_hc_t *hc, unsigned int cntr) +{ + dpu_sync_t *lsync = (dpu_sync_t*)hc->mem_segs.sync.base; + + while( lsync->itt < cntr) { + ucp_worker_progress(hc->ucp_worker); + } + return 0; +} + +unsigned int dpu_hc_get_dtype(dpu_hc_t *hc) +{ + dpu_sync_t *lsync = (dpu_sync_t*)hc->mem_segs.sync.base; + return lsync->dtype; +} + +unsigned int dpu_hc_get_op(dpu_hc_t *hc) +{ + dpu_sync_t *lsync = (dpu_sync_t*)hc->mem_segs.sync.base; + return lsync->op; +} + +unsigned int dpu_hc_get_count(dpu_hc_t *hc) +{ + dpu_sync_t *lsync = (dpu_sync_t*)hc->mem_segs.sync.base; + return lsync->len; +} + +int dpu_hc_reply(dpu_hc_t *hc, unsigned int cntr) +{ + dpu_sync_t *lsync = (dpu_sync_t*)hc->mem_segs.sync.base; + uint32_t rsync; + ucp_request_param_t req_param; + void *request; + dpu_req_t req_ctx = { 0 }; + int ret; + + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE | + UCP_OP_ATTR_FIELD_USER_DATA; + req_param.datatype = ucp_dt_make_contig(1); + req_param.cb.send = send_cb; + req_param.user_data = &req_ctx; + +// static unsigned int cntr = 1; +// while( lsync->itt < cntr) { +// ucp_worker_progress(hc->ucp_worker); +// } + + request = ucp_put_nbx(hc->host_ep, &cntr, sizeof(cntr), + hc->sync_addr, hc->sync_rkey, + &req_param); + ret = _dpu_request_finalize(hc->ucp_worker, request, &req_ctx); + if (ret) { + return -1; + } +// cntr++; + return 0; +} + +#if 0 +{ +/* Work loop */ +/* TEST + * **** */ + +free(worker_attr.address); +free(rem_worker_addr); +close(connfd); + +ep_close(ucp_worker, dpu_ep); + + +printf ("END %s\n", __FUNCTION__); + +return ret; + +err: +return ret; + +} +#endif diff --git a/contrib/dpu_daemon/host_channel.h b/contrib/dpu_daemon/host_channel.h new file mode 100644 index 0000000..6b4f70c --- /dev/null +++ b/contrib/dpu_daemon/host_channel.h @@ -0,0 +1,97 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#ifndef HOST_CHANNEL_H +#define HOST_CHANNEL_H + +#define _DEFAULT_SOURCE + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +#define IP_STRING_LEN 50 +#define PORT_STRING_LEN 8 +#define SUCCESS 0 +#define ERROR 1 +#define DEFAULT_PORT 13337 + +#define DATA_BUFFER_SIZE (16*1024*1024) + +#define EXCHANGE_LENGTH_TAG 1ull +#define EXCHANGE_RKEY_TAG 2ull +#define EXCHANGE_ADDR_TAG 3ull + + +typedef struct dpu_req_s { + int complete; +} dpu_req_t; +/* sync struct type + * use it for counter, dtype, ar op, length */ +typedef struct dpu_sync_s { + unsigned int itt; + unsigned int dtype; + unsigned int op; + unsigned int len; +} dpu_sync_t; + +typedef struct dpu_rkey_s { + void *rkey_addr; + size_t rkey_addr_len; +} dpu_rkey_t; + +typedef struct dpu_mem_s { + void *base; + ucp_mem_h memh; + dpu_rkey_t rkey; +} dpu_mem_t; + +typedef struct dpu_mem_segs_s { + dpu_mem_t sync; + dpu_mem_t put; + dpu_mem_t get; +} dpu_mem_segs_t; + +typedef struct dpu_hc_s { + /* TCP/IP stuff */ + char *hname; + char *ip; + int connfd, listenfd; + uint16_t port; + /* Local UCX stuff */ + ucp_context_h ucp_ctx; + ucp_worker_h ucp_worker; + ucp_worker_attr_t worker_attr; + union { + dpu_mem_segs_t mem_segs; + dpu_mem_t mem_segs_array[3]; + }; + /* Remote UCX stuff */ + ucp_ep_h host_ep; + uint64_t sync_addr; + ucp_rkey_h sync_rkey; + + /* bufer size*/ + size_t data_buffer_size; +} dpu_hc_t; + +int dpu_hc_init(dpu_hc_t *dpu_hc); +int dpu_hc_accept(dpu_hc_t *hc); +int dpu_hc_reply(dpu_hc_t *hc, unsigned int itt); +int dpu_hc_wait(dpu_hc_t *hc, unsigned int itt); +unsigned int dpu_hc_get_count(dpu_hc_t *hc); +unsigned int dpu_hc_get_dtype(dpu_hc_t *hc); +unsigned int dpu_hc_get_op(dpu_hc_t *hc); + +#endif diff --git a/contrib/dpu_daemon/server_xccl.c b/contrib/dpu_daemon/server_xccl.c new file mode 100644 index 0000000..8537065 --- /dev/null +++ b/contrib/dpu_daemon/server_xccl.c @@ -0,0 +1,206 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#include "server_xccl.h" +#include + +typedef struct xccl_test_oob_allgather_req { + xccl_ep_range_t range; + void *sbuf; + void *rbuf; + void *oob_coll_ctx; + int my_rank; + size_t msglen; + int iter; + MPI_Request reqs[2]; +} xccl_test_oob_allgather_req_t; + +static xccl_status_t oob_allgather_test(void *req) +{ + xccl_test_oob_allgather_req_t *oob_req = + (xccl_test_oob_allgather_req_t*)req; + int rank, size, sendto, recvfrom, recvdatafrom, senddatafrom, completed, probe; + char *tmpsend = NULL, *tmprecv = NULL; + size_t msglen = oob_req->msglen; + const int probe_count = 1; + MPI_Comm comm = (MPI_Comm)oob_req->oob_coll_ctx; + + if (oob_req->range.type == XCCL_EP_RANGE_UNDEFINED) { + MPI_Comm_size(comm, &size); + MPI_Comm_rank(comm, &rank); + } else { + size = oob_req->range.ep_num; + rank = oob_req->my_rank; + } + if (oob_req->iter == 0) { + tmprecv = (char*) oob_req->rbuf + (ptrdiff_t)rank * (ptrdiff_t)msglen; + memcpy(tmprecv, oob_req->sbuf, msglen); + } + sendto = (rank + 1) % size; + recvfrom = (rank - 1 + size) % size; + if (oob_req->range.type != XCCL_EP_RANGE_UNDEFINED) { + sendto = xccl_range_to_rank(oob_req->range, sendto); + recvfrom = xccl_range_to_rank(oob_req->range, recvfrom); + } + for (; oob_req->iter < size - 1; oob_req->iter++) { + if (oob_req->iter > 0) { + probe = 0; + do { + MPI_Testall(2, oob_req->reqs, &completed, MPI_STATUS_IGNORE); + probe++; + } while (!completed && probe < probe_count); + if (!completed) { + return XCCL_INPROGRESS; + } + } + recvdatafrom = (rank - oob_req->iter - 1 + size) % size; + senddatafrom = (rank - oob_req->iter + size) % size; + tmprecv = (char*)oob_req->rbuf + (ptrdiff_t)recvdatafrom * (ptrdiff_t)msglen; + tmpsend = (char*)oob_req->rbuf + (ptrdiff_t)senddatafrom * (ptrdiff_t)msglen; + MPI_Isend(tmpsend, msglen, MPI_BYTE, sendto, 2703, + comm, &oob_req->reqs[0]); + MPI_Irecv(tmprecv, msglen, MPI_BYTE, recvfrom, 2703, + comm, &oob_req->reqs[1]); + } + probe = 0; + do { + MPI_Testall(2, oob_req->reqs, &completed, MPI_STATUS_IGNORE); + probe++; + } while (!completed && probe < probe_count); + if (!completed) { + return XCCL_INPROGRESS; + } + return XCCL_OK; +} + +static xccl_status_t oob_allgather_free(void *req) +{ + free(req); + return XCCL_OK; +} + +static xccl_status_t oob_allgather(void *sbuf, void *rbuf, size_t msglen, + int my_rank, xccl_ep_range_t range, + void *oob_coll_ctx, void **req) +{ + xccl_test_oob_allgather_req_t *oob_req = malloc(sizeof(*oob_req)); + oob_req->sbuf = sbuf; + oob_req->rbuf = rbuf; + oob_req->msglen = msglen; + oob_req->range = range; + oob_req->oob_coll_ctx = oob_coll_ctx; + oob_req->my_rank = my_rank; + oob_req->iter = 0; + *req = oob_req; + return oob_allgather_test(*req); +} + +int xccl_mpi_create_team_nb(dpu_xccl_comm_t *comm) { + /* Create XCCL TEAM for comm world */ + xccl_team_params_t team_params = { + .field_mask = XCCL_TEAM_PARAM_FIELD_EP_RANGE | + XCCL_TEAM_PARAM_FIELD_OOB, + .range = { + .type = XCCL_EP_RANGE_STRIDED, + .strided.start = 0, + .strided.stride = 1 + }, + + .oob = { + .allgather = oob_allgather, + .req_test = oob_allgather_test, + .req_free = oob_allgather_free, + .coll_context = (void*)MPI_COMM_WORLD, + .rank = comm->g->rank, + .size = comm->g->size + } + }; + + XCCL_CHECK(xccl_team_create_post(comm->ctx, &team_params, &comm->team)); + return 0; +} + +int xccl_mpi_create_team(dpu_xccl_comm_t *comm) { + xccl_mpi_create_team_nb(comm); + while (XCCL_INPROGRESS == xccl_team_create_test(comm->team)) {;}; +} + +int dpu_xccl_init(int argc, char **argv, dpu_xccl_global_t *g) +{ + char *var; + xccl_tl_id_t *tl_ids; + unsigned tl_count; + MPI_Init(&argc, &argv); + MPI_Comm_rank(MPI_COMM_WORLD, &g->rank); + MPI_Comm_size(MPI_COMM_WORLD, &g->size); + + /* Init xccl library */ + var = getenv(DPU_XCCL_TLS); + if (var) { + g->tls = xccl_tls_str_to_bitmap(var); + } + else { + g->tls = XCCL_TL_ALL; + } + + xccl_lib_params_t lib_params = { + .field_mask = XCCL_LIB_PARAM_FIELD_TEAM_USAGE | + XCCL_LIB_PARAM_FIELD_COLL_TYPES, + .team_usage = XCCL_LIB_PARAMS_TEAM_USAGE_SW_COLLECTIVES | + XCCL_LIB_PARAMS_TEAM_USAGE_HW_COLLECTIVES, + /* TODO: support more collectives */ + .coll_types = XCCL_COLL_CAP_ALLREDUCE, + }; + + XCCL_CHECK(xccl_lib_init(&lib_params, NULL, &g->lib)); + XCCL_CHECK(xccl_get_tl_list(g->lib, &tl_ids, &tl_count)); + xccl_free_tl_list(tl_ids); + return XCCL_OK; +} + +int dpu_xccl_alloc_team(dpu_xccl_global_t *g, dpu_xccl_comm_t *comm) +{ + /* Init xccl context for a specified XCCL_TEST_TLS */ + xccl_context_params_t ctx_params = { + .field_mask = XCCL_CONTEXT_PARAM_FIELD_THREAD_MODE | + XCCL_CONTEXT_PARAM_FIELD_OOB | + XCCL_CONTEXT_PARAM_FIELD_TEAM_COMPLETION_TYPE | + XCCL_CONTEXT_PARAM_FIELD_TLS, + .thread_mode = XCCL_THREAD_MODE_SINGLE, + .completion_type = XCCL_TEAM_COMPLETION_TYPE_BLOCKING, + .oob = { + .allgather = oob_allgather, + .req_test = oob_allgather_test, + .req_free = oob_allgather_free, + .coll_context = (void*)MPI_COMM_WORLD, + .rank = g->rank, + .size = g->size + }, + .tls = g->tls, + }; + xccl_context_config_t *ctx_config; + XCCL_CHECK(xccl_context_config_read(g->lib, NULL, NULL, &ctx_config)); + XCCL_CHECK(xccl_context_create(g->lib, &ctx_params, ctx_config, &comm->ctx)); + xccl_context_config_release(ctx_config); + comm->g = g; + xccl_mpi_create_team(comm); + return XCCL_OK; +} + +int dpu_xccl_free_team(dpu_xccl_global_t *g, dpu_xccl_comm_t *team) +{ + xccl_team_destroy(team->team); + xccl_context_destroy(team->ctx); +} + +void dpu_xccl_finalize(dpu_xccl_global_t *g) { + xccl_lib_cleanup(g->lib); + MPI_Finalize(); +} + +void dpu_xccl_progress(dpu_xccl_comm_t *comm) +{ + xccl_context_progress(comm->ctx); +} diff --git a/contrib/dpu_daemon/server_xccl.h b/contrib/dpu_daemon/server_xccl.h new file mode 100644 index 0000000..b02b9ac --- /dev/null +++ b/contrib/dpu_daemon/server_xccl.h @@ -0,0 +1,43 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#ifndef TEST_MPI_H +#define TEST_MPI_H + +#include +#include +#include +#include +#include +#include + +#define DPU_XCCL_TLS "DPU_XCCL_TLS" + +#define STR(x) # x +#define XCCL_CHECK(_call) if (XCCL_OK != (_call)) { \ + fprintf(stderr, "*** XCCL TEST FAIL: %s\n", STR(_call)); \ + MPI_Abort(MPI_COMM_WORLD, -1); \ + } + +typedef struct { + xccl_team_h xccl_world_team; + xccl_lib_h lib; + int rank, size; + uint64_t tls; +} dpu_xccl_global_t; + +typedef struct { + dpu_xccl_global_t *g; + xccl_context_h ctx; + xccl_team_h team; +} dpu_xccl_comm_t; + +int dpu_xccl_init(int argc, char **argv, dpu_xccl_global_t *g); +int dpu_xccl_alloc_team(dpu_xccl_global_t *g, dpu_xccl_comm_t *team); +int dpu_xccl_free_team(dpu_xccl_global_t *g, dpu_xccl_comm_t *ctx); +void dpu_xccl_finalize(dpu_xccl_global_t *g); +void dpu_xccl_progress(dpu_xccl_comm_t *team); + +#endif diff --git a/src/Makefile.am b/src/Makefile.am index fb5626a..60b0e1c 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -19,6 +19,9 @@ endif if HAVE_NCCL SUBDIRS += team_lib/nccl endif +if HAVE_DPU +SUBDIRS += team_lib/dpu +endif if HAVE_CUDA SUBDIRS += utils/cuda endif diff --git a/src/api/xccl_tls.h b/src/api/xccl_tls.h index 6645e22..214f51a 100644 --- a/src/api/xccl_tls.h +++ b/src/api/xccl_tls.h @@ -18,6 +18,7 @@ typedef enum xccl_tl_id { XCCL_TL_SHMSEG = UCS_BIT(4), XCCL_TL_MRAIL = UCS_BIT(5), XCCL_TL_NCCL = UCS_BIT(6), + XCCL_TL_DPU = UCS_BIT(7), XCCL_TL_LAST, XCCL_TL_ALL = (XCCL_TL_LAST << 1) - 3 } xccl_tl_id_t; @@ -40,6 +41,8 @@ const char* xccl_tl_str(xccl_tl_id_t tl_id) return "shmseg"; case XCCL_TL_NCCL: return "nccl"; + case XCCL_TL_DPU: + return "dpu"; default: break; } @@ -59,7 +62,7 @@ xccl_tl_id_t xccl_tls_str_to_bitmap(const char *tls_str) for (i = 1; i < XCCL_TL_LAST; i = i << 1) { if (strstr(tls_str, xccl_tl_str((xccl_tl_id_t)i))) { tls = (xccl_tl_id_t)(tls | i); - } + } } return tls; diff --git a/src/team_lib/dpu/Makefile.am b/src/team_lib/dpu/Makefile.am new file mode 100644 index 0000000..446911f --- /dev/null +++ b/src/team_lib/dpu/Makefile.am @@ -0,0 +1,24 @@ +# +# Copyright (c) 2020 Mellanox Technologies. All rights reserved. +# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +if HAVE_DPU +sources = \ + xccl_dpu_lib.c \ + xccl_dpu_lib.h + +component_noinst = +component_install = xccl_team_lib_dpu.la + +xccl_team_lib_dpu_la_SOURCES =$(sources) +xccl_team_lib_dpu_la_CPPFLAGS = $(AM_CPPFLAGS) $(CPPFLAGS) $(UCX_CPPFLAGS) +xccl_team_lib_dpu_la_LDFLAGS = -module -avoid-version $(UCX_LDFLAGS) $(UCX_LIBADD) +xccl_team_lib_dpu_la_LIBADD = $(XCCL_TOP_BUILDDIR)/src/libxccl.la + +pkglib_LTLIBRARIES = $(component_install) +endif diff --git a/src/team_lib/dpu/xccl_dpu_lib.c b/src/team_lib/dpu/xccl_dpu_lib.c new file mode 100644 index 0000000..70da714 --- /dev/null +++ b/src/team_lib/dpu/xccl_dpu_lib.c @@ -0,0 +1,662 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#include "config.h" +#include "xccl_dpu_lib.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static ucs_config_field_t xccl_team_lib_dpu_config_table[] = { + {"", "", + NULL, + ucs_offsetof(xccl_team_lib_dpu_config_t, super), + UCS_CONFIG_TYPE_TABLE(xccl_team_lib_config_table) + }, + + {NULL} +}; + +static ucs_config_field_t xccl_tl_dpu_context_config_table[] = { + {"", "", + NULL, + ucs_offsetof(xccl_tl_dpu_context_config_t, super), + UCS_CONFIG_TYPE_TABLE(xccl_tl_context_config_table) + }, + + {"SERVER_HOSTNAME", "", + "Bluefield IP address", + ucs_offsetof(xccl_tl_dpu_context_config_t, server_hname), + UCS_CONFIG_TYPE_STRING + }, + + {"SERVER_PORT", "13337", + "Bluefield DPU port", + ucs_offsetof(xccl_tl_dpu_context_config_t, server_port), + UCS_CONFIG_TYPE_UINT + }, + + {"ENABLE", "0", + "Assume server is running on BF", + ucs_offsetof(xccl_tl_dpu_context_config_t, use_dpu), + UCS_CONFIG_TYPE_UINT + }, + + {"HOST_DPU_LIST", "", + "A host-dpu list used to identify the DPU IP", + ucs_offsetof(xccl_tl_dpu_context_config_t, host_dpu_list), + UCS_CONFIG_TYPE_STRING + }, + + {NULL} +}; + +static xccl_status_t xccl_dpu_lib_open(xccl_team_lib_h self, + xccl_team_lib_config_t *config) +{ + xccl_team_lib_dpu_t *tl = ucs_derived_of(self, xccl_team_lib_dpu_t); + xccl_team_lib_dpu_config_t *cfg = ucs_derived_of(config, xccl_team_lib_dpu_config_t); + + tl->config.super.log_component.log_level = cfg->super.log_component.log_level; + sprintf(tl->config.super.log_component.name, "%s", "TEAM_DPU"); + xccl_dpu_debug("Team DPU opened"); + if (cfg->super.priority != -1) { + tl->super.priority = cfg->super.priority; + } + + return XCCL_OK; +} + +#define EXCHANGE_LENGTH_TAG 1ull +#define EXCHANGE_RKEY_TAG 2ull +#define EXCHANGE_ADDR_TAG 3ull + +static void err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) +{ + xccl_dpu_error("error handling callback was invoked with status %d (%s)\n", + status, ucs_status_string(status)); +} + +typedef enum { + UCX_REQUEST_ACTIVE, + UCX_REQUEST_DONE, +} ucx_request_status_t; + +typedef struct ucx_request { + ucx_request_status_t status; +} ucx_request_t; + +static void ucx_req_init(void* request) +{ + ucx_request_t *req = (ucx_request_t*)request; + req->status = UCX_REQUEST_ACTIVE; +} + +static void ucx_req_cleanup(void* request){ } + +static int _server_connect(char *hname, uint16_t port) +{ + int sock, n; + struct sockaddr_in addr; + struct addrinfo *res, *t; + struct addrinfo hints = { 0 }; + char service[64]; + + + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + sprintf(service, "%d", port); + n = getaddrinfo(hname, service, &hints, &res); + + if (n < 0) { + xccl_dpu_error("%s:%d: getaddrinfo(): %s for %s:%s\n", + __FILE__,__LINE__, + gai_strerror(n), hname, service); + return -1; + } + + for (t = res; t; t = t->ai_next) { + sock = socket(t->ai_family, t->ai_socktype, t->ai_protocol); + if (sock >= 0) { + if (!connect(sock, t->ai_addr, t->ai_addrlen)) + break; + close(sock); + sock = -1; + } + } + freeaddrinfo(res); + return sock; +} + + +static xccl_status_t +xccl_dpu_context_create(xccl_team_lib_h lib, xccl_context_params_t *params, + xccl_tl_context_config_t *config, + xccl_tl_context_t **context) +{ + xccl_dpu_context_t *ctx = malloc(sizeof(*ctx)); + ucp_params_t ucp_params; + ucp_worker_params_t worker_params; + ucp_ep_params_t ep_params; + ucs_status_t status; + xccl_tl_dpu_context_config_t *cfg; + char hname[256]; + void *rem_worker_addr; + size_t rem_worker_addr_size; + int sockfd = 0, found = 0; + + cfg = ucs_derived_of(config, xccl_tl_dpu_context_config_t); + XCCL_CONTEXT_SUPER_INIT(ctx->super, lib, params); + +#if 0 + if (atoi(getenv("OMPI_COMM_WORLD_RANK")) == 0) { + strcpy(cfg->server_hname, "thor001"); + } + else if (atoi(getenv("OMPI_COMM_WORLD_RANK")) == 1) { + strcpy(cfg->server_hname, "thor002"); + } + else { + fprintf(stderr, "error %s", __FUNCTION__); + } +#endif + gethostname(hname, sizeof(hname) - 1); + if (cfg->use_dpu) { + char *h = calloc(1, 256); + FILE *fp = NULL; + + if (strcmp(cfg->host_dpu_list,"") != 0) { + + fp = fopen(cfg->host_dpu_list, "r"); + if (fp == NULL) { + fprintf(stderr, "Unable to open \"%s\", disabling dpu team\n", cfg->host_dpu_list); + cfg->use_dpu = 0; + found = 0; + } + else { + while (fscanf(fp,"%s", h) != EOF) { + if (strcmp(h, hname) == 0) { + found = 1; + fscanf(fp, "%s", hname); + fprintf(stderr, "DPU <%s> found!\n", hname); + break; + } + memset(h, 0, 256); + } + } + if (!found) { + cfg->use_dpu = 0; + } + } + else { + fprintf(stderr, "DPU_ENABLE set, but HOST_LIST not specified. Disabling DPU team!\n"); + cfg->use_dpu = 0; + } + free(h); + } + else { + goto err; + } + + if (!found) { + goto err; + } + + xccl_dpu_info("Connecting to %s", hname); + sockfd = _server_connect(hname, cfg->server_port); + + memset(&ucp_params, 0, sizeof(ucp_params)); + ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES | + UCP_PARAM_FIELD_REQUEST_SIZE | + UCP_PARAM_FIELD_REQUEST_INIT | + UCP_PARAM_FIELD_REQUEST_CLEANUP; + ucp_params.features = UCP_FEATURE_TAG | + UCP_FEATURE_RMA; + ucp_params.request_size = sizeof(ucx_request_t); + ucp_params.request_init = ucx_req_init; + ucp_params.request_cleanup = ucx_req_cleanup; + + status = ucp_init(&ucp_params, NULL, &ctx->ucp_context); + if (status != UCS_OK) { + xccl_dpu_error("failed ucp_init(%s)\n", ucs_status_string(status)); + goto err; + } + + memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; + status = ucp_worker_create(ctx->ucp_context, &worker_params, &ctx->ucp_worker); + if (status != UCS_OK) { + xccl_dpu_error("failed ucp_worker_create (%s)\n", ucs_status_string(status)); + goto err_cleanup_context; + } + + ucp_worker_attr_t attr; + attr.field_mask = UCP_WORKER_ATTR_FIELD_ADDRESS | + UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS; + attr.address_flags = UCP_WORKER_ADDRESS_FLAG_NET_ONLY; + ucp_worker_query(ctx->ucp_worker, &attr); + int ret; + ret = send(sockfd, &attr.address_length, sizeof(&attr.address_length), 0); + if (ret<0) { + xccl_dpu_error("send length failed"); + } + ret = send(sockfd, attr.address, attr.address_length, 0); + if (ret<0) { + xccl_dpu_error("send address failed"); + } + ret = recv(sockfd, &rem_worker_addr_size, sizeof(rem_worker_addr_size), MSG_WAITALL); + if (ret<0) { + xccl_dpu_error("recv address length failed"); + } + rem_worker_addr = malloc(rem_worker_addr_size); + ret = recv(sockfd, rem_worker_addr, rem_worker_addr_size, MSG_WAITALL); + if (ret<0) { + xccl_dpu_error("recv address failed"); + } + ep_params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | + UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | + UCP_EP_PARAM_FIELD_ERR_HANDLER | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; + ep_params.err_mode = UCP_ERR_HANDLING_MODE_PEER; + ep_params.err_handler.cb = err_cb; + ep_params.address = rem_worker_addr; + + status = ucp_ep_create(ctx->ucp_worker, &ep_params, &ctx->ucp_ep); + free(attr.address); + free(rem_worker_addr); + close(sockfd); + if (status != UCS_OK) { + xccl_dpu_error("failed to connect to %s (%s)\n", + hname, ucs_status_string(status)); + goto err_cleanup_worker; + } + + *context = &ctx->super; + + xccl_dpu_debug("context created"); + return XCCL_OK; + +err_cleanup_worker: + ucp_worker_destroy(ctx->ucp_worker); +err_cleanup_context: + ucp_cleanup(ctx->ucp_context); +err: + return XCCL_ERR_NO_MESSAGE; +} + +static xccl_status_t +xccl_dpu_context_destroy(xccl_tl_context_t *context) +{ + xccl_dpu_context_t *dpu_ctx = ucs_derived_of(context, xccl_dpu_context_t); + ucp_request_param_t param; + ucs_status_t status; + void *close_req; + + param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS; + param.flags = UCP_EP_CLOSE_FLAG_FORCE; + close_req = ucp_ep_close_nbx(dpu_ctx->ucp_ep, ¶m); + if (UCS_PTR_IS_PTR(close_req)) { + do { + ucp_worker_progress(dpu_ctx->ucp_worker); + status = ucp_request_check_status(close_req); + } while (status == UCS_INPROGRESS); + ucp_request_free (close_req); + } else if (UCS_PTR_STATUS(close_req) != UCS_OK) { + xccl_dpu_error("failed to close ep %p\n", (void *)dpu_ctx->ucp_ep); + return XCCL_ERR_NO_MESSAGE; + } + ucp_worker_destroy(dpu_ctx->ucp_worker); + ucp_cleanup(dpu_ctx->ucp_context); + free(dpu_ctx); + + return XCCL_OK; +} + +static void send_handler_nbx(void *request, ucs_status_t status, + void *user_data) { + ucx_request_t *req = (ucx_request_t*)request; + req->status = UCX_REQUEST_DONE; +} + +void recv_handler_nbx(void *request, ucs_status_t status, + const ucp_tag_recv_info_t *tag_info, + void *user_data) { + ucx_request_t *req = (ucx_request_t*)request; + req->status = UCX_REQUEST_DONE; +} + +static xccl_status_t ucx_req_test(ucx_request_t **req, ucp_worker_h worker) { + if (*req == NULL) { + return XCCL_OK; + } + + if ((*req)->status == UCX_REQUEST_DONE) { + (*req)->status = UCX_REQUEST_ACTIVE; + ucp_request_free(*req); + (*req) = NULL; + return XCCL_OK; + } + ucp_worker_progress(worker); + return XCCL_INPROGRESS; +} + +static xccl_status_t ucx_req_check(ucx_request_t *req) { + if (UCS_PTR_IS_ERR(req)) { + xccl_dpu_error("failed to send/recv msg"); + return XCCL_ERR_NO_MESSAGE; + } + return XCCL_OK; +} + +static xccl_status_t +xccl_dpu_team_create_post(xccl_tl_context_t *context, + xccl_team_params_t *params, + xccl_tl_team_t **team) +{ + xccl_dpu_context_t *ctx = ucs_derived_of(context, xccl_dpu_context_t); + xccl_dpu_team_t *dpu_team = malloc(sizeof(*dpu_team)); + ucp_mem_map_params_t mmap_params; + ucp_request_param_t send_req_param, recv_req_param; + ucx_request_t *send_req[3], *recv_req[2]; + size_t rem_rkeys_lengths[3]; + uint64_t rem_addresses[3]; + void *ctrl_seg_rkey_buf; + size_t ctrl_seg_rkey_buf_size; + size_t total_rkey_size; + void *rem_rkeys; + + XCCL_TEAM_SUPER_INIT(dpu_team->super, context, params); + + dpu_team->coll_id = 1; + mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH; + mmap_params.address = (void*)&dpu_team->ctrl_seg; + mmap_params.length = sizeof(dpu_team->ctrl_seg); + + ucp_mem_map(ctx->ucp_context, &mmap_params, &dpu_team->ctrl_seg_memh); + ucp_rkey_pack(ctx->ucp_context, dpu_team->ctrl_seg_memh, + &ctrl_seg_rkey_buf, + &ctrl_seg_rkey_buf_size); + + send_req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE; + send_req_param.datatype = ucp_dt_make_contig(1); + send_req_param.cb.send = send_handler_nbx; + + recv_req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE; + recv_req_param.datatype = ucp_dt_make_contig(1); + recv_req_param.cb.recv = recv_handler_nbx; + + send_req[0] = ucp_tag_send_nbx(ctx->ucp_ep, &mmap_params.address, + sizeof(uint64_t), EXCHANGE_ADDR_TAG, + &send_req_param); + if (ucx_req_check(send_req[0]) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + + send_req[1] = ucp_tag_send_nbx(ctx->ucp_ep, &ctrl_seg_rkey_buf_size, + sizeof(size_t), EXCHANGE_LENGTH_TAG, + &send_req_param); + if (ucx_req_check(send_req[1]) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + send_req[2] = ucp_tag_send_nbx(ctx->ucp_ep, ctrl_seg_rkey_buf, + ctrl_seg_rkey_buf_size, EXCHANGE_RKEY_TAG, + &send_req_param); + if (ucx_req_check(send_req[2]) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + recv_req[0] = ucp_tag_recv_nbx(ctx->ucp_worker, rem_rkeys_lengths, + sizeof(rem_rkeys_lengths), + EXCHANGE_LENGTH_TAG, (uint64_t)-1, + &recv_req_param); + if (ucx_req_check(recv_req[0]) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + do { + ucp_worker_progress(ctx->ucp_worker); + } while((ucx_req_test(&(send_req[0]), ctx->ucp_worker) != XCCL_OK) || + (ucx_req_test(&(send_req[1]), ctx->ucp_worker) != XCCL_OK) || + (ucx_req_test(&(send_req[2]), ctx->ucp_worker) != XCCL_OK) || + (ucx_req_test(&(recv_req[0]), ctx->ucp_worker) != XCCL_OK)); + +// fprintf (stderr,"%lu ;%lu %lu\n", rem_rkeys_lengths[0], rem_rkeys_lengths[1], rem_rkeys_lengths[2]); + + ucp_rkey_buffer_release(ctrl_seg_rkey_buf); + total_rkey_size = rem_rkeys_lengths[0] + rem_rkeys_lengths[1] + rem_rkeys_lengths[2]; + rem_rkeys = malloc(total_rkey_size); + recv_req[0] = ucp_tag_recv_nbx(ctx->ucp_worker, &rem_addresses, + sizeof(rem_addresses), + EXCHANGE_ADDR_TAG, (uint64_t)-1, + &recv_req_param); + if (ucx_req_check(recv_req[0]) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + + recv_req[1] = ucp_tag_recv_nbx(ctx->ucp_worker, rem_rkeys, + total_rkey_size, + EXCHANGE_RKEY_TAG, (uint64_t)-1, + &recv_req_param); + if (ucx_req_check(recv_req[1]) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + do { + ucp_worker_progress(ctx->ucp_worker); + } while((ucx_req_test(&recv_req[0], ctx->ucp_worker) != XCCL_OK) || + (ucx_req_test(&recv_req[1], ctx->ucp_worker) != XCCL_OK)); + + dpu_team->rem_ctrl_seg = rem_addresses[0]; + ucp_ep_rkey_unpack(ctx->ucp_ep, rem_rkeys, &dpu_team->rem_ctrl_seg_key); + dpu_team->rem_data_in = rem_addresses[1]; + ucp_ep_rkey_unpack(ctx->ucp_ep, (void*)((ptrdiff_t)rem_rkeys + rem_rkeys_lengths[0]), + &dpu_team->rem_data_in_key); + dpu_team->rem_data_out = rem_addresses[2]; + ucp_ep_rkey_unpack(ctx->ucp_ep, (void*)((ptrdiff_t)rem_rkeys + + rem_rkeys_lengths[1] + rem_rkeys_lengths[0]), + &dpu_team->rem_data_out_key); + free(rem_rkeys); + + *team = &dpu_team->super; + return XCCL_OK; +} + +static xccl_status_t xccl_dpu_team_create_test(xccl_tl_team_t *team) +{ + return XCCL_OK; +} + +static xccl_status_t xccl_dpu_team_destroy(xccl_tl_team_t *team) +{ + xccl_dpu_team_t *dpu_team = ucs_derived_of(team, xccl_dpu_team_t); + xccl_dpu_context_t *dpu_ctx = ucs_derived_of(team->ctx, xccl_dpu_context_t); + dpu_sync_t hangup; + ucx_request_t *hangup_req; + ucp_request_param_t req_param; + + hangup.itt = dpu_team->coll_id; + hangup.dtype = XCCL_DT_UNSUPPORTED; + hangup.op = XCCL_OP_UNSUPPORTED; + hangup.len = 0; + + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE; + req_param.datatype = ucp_dt_make_contig(1); + req_param.cb.send = send_handler_nbx; + + hangup_req = ucp_put_nbx(dpu_ctx->ucp_ep, &hangup, sizeof(hangup), + dpu_team->rem_ctrl_seg, dpu_team->rem_ctrl_seg_key, + &req_param); + if (ucx_req_check(hangup_req) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + do { + ucp_worker_progress(dpu_ctx->ucp_worker); + } while((ucx_req_test(&(hangup_req), dpu_ctx->ucp_worker) != XCCL_OK)); + + ucp_rkey_destroy(dpu_team->rem_ctrl_seg_key); + ucp_rkey_destroy(dpu_team->rem_data_in_key); + ucp_rkey_destroy(dpu_team->rem_data_out_key); + ucp_mem_unmap(dpu_ctx->ucp_context, dpu_team->ctrl_seg_memh); + free(team); + return XCCL_OK; +} + +static xccl_status_t xccl_dpu_collective_init(xccl_coll_op_args_t *coll_args, + xccl_tl_coll_req_t **request, + xccl_tl_team_t *team) +{ + xccl_dpu_info("Collective init"); + xccl_dpu_coll_req_t *req = (xccl_dpu_coll_req_t*)malloc(sizeof(xccl_dpu_coll_req_t)); + xccl_dpu_team_t *dpu_team = ucs_derived_of(team, xccl_dpu_team_t); + + if (req == NULL) { + return XCCL_ERR_NO_MEMORY; + } + memcpy(&req->args, coll_args, sizeof(xccl_coll_op_args_t)); + req->sync.itt = dpu_team->coll_id; + req->sync.dtype = coll_args->reduce_info.dt; + req->sync.len = coll_args->reduce_info.count; + req->sync.op = coll_args->reduce_info.op; + req->team = team; + *request = &req->super; + (*request)->lib = &xccl_team_lib_dpu.super; + dpu_team->coll_id++; + return XCCL_OK; +} + +static xccl_status_t xccl_dpu_collective_post(xccl_tl_coll_req_t *request) +{ + xccl_dpu_coll_req_t *req = ucs_derived_of(request, xccl_dpu_coll_req_t); + xccl_dpu_team_t *dpu_team = ucs_derived_of(req->team, xccl_dpu_team_t); + xccl_dpu_context_t *dpu_ctx = ucs_derived_of(req->team->ctx, xccl_dpu_context_t); + ucp_request_param_t req_param; + ucx_request_t *send_req[2]; + + xccl_dpu_info("Collective post"); + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE; + req_param.datatype = ucp_dt_make_contig(1); + req_param.cb.send = send_handler_nbx; + + send_req[0] = ucp_put_nbx(dpu_ctx->ucp_ep, req->args.buffer_info.src_buffer, + req->args.reduce_info.count * xccl_dt_size(req->args.reduce_info.dt), + dpu_team->rem_data_in, dpu_team->rem_data_in_key, + &req_param); + if (ucx_req_check(send_req[0]) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + ucp_worker_fence(dpu_ctx->ucp_worker); + send_req[1] = ucp_put_nbx(dpu_ctx->ucp_ep, &req->sync, sizeof(req->sync), + dpu_team->rem_ctrl_seg, dpu_team->rem_ctrl_seg_key, + &req_param); + if (ucx_req_check(send_req[1]) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + do { + ucp_worker_progress(dpu_ctx->ucp_worker); + } while((ucx_req_test(&(send_req[0]), dpu_ctx->ucp_worker) != XCCL_OK) || + (ucx_req_test(&(send_req[1]), dpu_ctx->ucp_worker) != XCCL_OK)); + + return XCCL_OK; +} + +static xccl_status_t xccl_dpu_collective_wait(xccl_tl_coll_req_t *request) +{ + xccl_dpu_info("Collective wait"); + fprintf(stderr, "collective wait is not implemented"); + return XCCL_ERR_NOT_IMPLEMENTED; +} + +static xccl_status_t xccl_dpu_collective_test(xccl_tl_coll_req_t *request) +{ + xccl_dpu_coll_req_t *req = ucs_derived_of(request, xccl_dpu_coll_req_t); + xccl_dpu_team_t *dpu_team = ucs_derived_of(req->team, xccl_dpu_team_t); + xccl_dpu_context_t *dpu_ctx = ucs_derived_of(req->team->ctx, xccl_dpu_context_t); + ucp_request_param_t req_param; + ucx_request_t *recv_req; + volatile uint32_t *check_flag = dpu_team->ctrl_seg; + + if (dpu_team->coll_id != (*check_flag + 1)) { + return XCCL_INPROGRESS; + } + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_DATATYPE; + req_param.datatype = ucp_dt_make_contig(1); + req_param.cb.recv = recv_handler_nbx; + + recv_req = ucp_get_nbx(dpu_ctx->ucp_ep, req->args.buffer_info.dst_buffer, + req->args.reduce_info.count * xccl_dt_size(req->args.reduce_info.dt), + dpu_team->rem_data_out, dpu_team->rem_data_out_key, + &req_param); + if (ucx_req_check(recv_req) != XCCL_OK) { + return XCCL_ERR_NO_MESSAGE; + } + do { + ucp_worker_progress(dpu_ctx->ucp_worker); + } while((ucx_req_test(&recv_req, dpu_ctx->ucp_worker) != XCCL_OK)); + + return XCCL_OK; +} + +static xccl_status_t xccl_dpu_collective_finalize(xccl_tl_coll_req_t *request) +{ + xccl_dpu_info("Collective finalize"); + free(request); + return XCCL_OK; +} + +xccl_team_lib_dpu_t xccl_team_lib_dpu = { + .super.name = "dpu", + .super.id = XCCL_TL_DPU, + .super.priority = 90, + .super.team_lib_config = + { + .name = "DPU team library", + .prefix = "TEAM_DPU_", + .table = xccl_team_lib_dpu_config_table, + .size = sizeof(xccl_team_lib_dpu_config_t), + }, + .super.tl_context_config = { + .name = "DPU tl context", + .prefix = "TEAM_DPU_", + .table = xccl_tl_dpu_context_config_table, + .size = sizeof(xccl_tl_dpu_context_config_t), + }, + .super.params.reproducible = XCCL_REPRODUCIBILITY_MODE_NON_REPRODUCIBLE, + .super.params.thread_mode = XCCL_THREAD_MODE_SINGLE | + XCCL_THREAD_MODE_MULTIPLE, + .super.params.team_usage = XCCL_LIB_PARAMS_TEAM_USAGE_HW_COLLECTIVES, + .super.params.coll_types = XCCL_COLL_CAP_ALLREDUCE, + .super.mem_types = UCS_BIT(UCS_MEMORY_TYPE_HOST) | + UCS_BIT(UCS_MEMORY_TYPE_CUDA), + .super.ctx_create_mode = XCCL_TEAM_LIB_CONTEXT_CREATE_MODE_LOCAL, + .super.team_context_create = xccl_dpu_context_create, + .super.team_context_destroy = xccl_dpu_context_destroy, + .super.team_context_progress = NULL, + .super.team_create_post = xccl_dpu_team_create_post, + .super.team_create_test = xccl_dpu_team_create_test, + .super.team_destroy = xccl_dpu_team_destroy, + .super.team_lib_open = xccl_dpu_lib_open, + .super.collective_init = xccl_dpu_collective_init, + .super.collective_post = xccl_dpu_collective_post, + .super.collective_wait = xccl_dpu_collective_wait, + .super.collective_test = xccl_dpu_collective_test, + .super.collective_finalize = xccl_dpu_collective_finalize, + .super.global_mem_map_start = NULL, + .super.global_mem_map_test = NULL, + .super.global_mem_unmap = NULL, +}; diff --git a/src/team_lib/dpu/xccl_dpu_lib.h b/src/team_lib/dpu/xccl_dpu_lib.h new file mode 100644 index 0000000..55dc414 --- /dev/null +++ b/src/team_lib/dpu/xccl_dpu_lib.h @@ -0,0 +1,82 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#ifndef XCCL_TEAM_LIB_DPU_H_ +#define XCCL_TEAM_LIB_DPU_H_ + +#include "xccl_team_lib.h" +#include + +typedef struct xccl_team_lib_dpu_config { + xccl_team_lib_config_t super; +} xccl_team_lib_dpu_config_t; + +typedef struct xccl_tl_dpu_context_config { + xccl_tl_context_config_t super; + char *server_hname; + uint32_t server_port; + uint32_t use_dpu; + char *host_dpu_list; +} xccl_tl_dpu_context_config_t; + +typedef struct xccl_team_lib_dpu { + xccl_team_lib_t super; + xccl_team_lib_dpu_config_t config; +} xccl_team_lib_dpu_t; + +extern xccl_team_lib_dpu_t xccl_team_lib_dpu; + +#define xccl_team_dpu_log_component(_level, _fmt, ...) \ + do { \ + ucs_log_component(_level, &xccl_team_lib_dpu.config.super.log_component, _fmt, ## __VA_ARGS__); \ + } while (0) + +#define xccl_dpu_error(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_ERROR, _fmt, ## __VA_ARGS__) +#define xccl_dpu_warn(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_WARN, _fmt, ## __VA_ARGS__) +#define xccl_dpu_info(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_INFO, _fmt, ## __VA_ARGS__) +#define xccl_dpu_debug(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_DEBUG, _fmt, ## __VA_ARGS__) +#define xccl_dpu_trace(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_TRACE, _fmt, ## __VA_ARGS__) +#define xccl_dpu_trace_req(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_TRACE_REQ, _fmt, ## __VA_ARGS__) +#define xccl_dpu_trace_data(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_TRACE_DATA, _fmt, ## __VA_ARGS__) +#define xccl_dpu_trace_async(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_TRACE_ASYNC, _fmt, ## __VA_ARGS__) +#define xccl_dpu_trace_func(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_TRACE_FUNC, "%s(" _fmt ")", __FUNCTION__, ## __VA_ARGS__) +#define xccl_dpu_trace_poll(_fmt, ...) xccl_team_dpu_log_component(UCS_LOG_LEVEL_TRACE_POLL, _fmt, ## __VA_ARGS__) + + +typedef struct xccl_dpu_context { + xccl_tl_context_t super; + ucp_context_h ucp_context; + ucp_worker_h ucp_worker; + ucp_ep_h ucp_ep; +} xccl_dpu_context_t; + +typedef struct xccl_dpu_team { + xccl_tl_team_t super; + uint32_t coll_id; + uint32_t ctrl_seg[1]; + ucp_mem_h ctrl_seg_memh; + uint64_t rem_ctrl_seg; + ucp_rkey_h rem_ctrl_seg_key; + uint64_t rem_data_in; + ucp_rkey_h rem_data_in_key; + uint64_t rem_data_out; + ucp_rkey_h rem_data_out_key; +} xccl_dpu_team_t; + +typedef struct dpu_sync_t { + unsigned int itt; + unsigned int dtype; + unsigned int op; + unsigned int len; +} dpu_sync_t; + +typedef struct xccl_dpu_coll_req { + xccl_tl_coll_req_t super; + xccl_tl_team_t *team; + xccl_coll_op_args_t args; + dpu_sync_t sync; +} xccl_dpu_coll_req_t; + +#endif diff --git a/src/team_lib/hier/xccl_hier_context.c b/src/team_lib/hier/xccl_hier_context.c index 1fd4a63..744436a 100644 --- a/src/team_lib/hier/xccl_hier_context.c +++ b/src/team_lib/hier/xccl_hier_context.c @@ -145,7 +145,7 @@ xccl_status_t xccl_hier_create_context(xccl_team_lib_t *lib, ctx->local_proc.pid = xccl_local_process_info()->pid; memset(ctx->tls, 0, sizeof(ctx->tls)); *context = NULL; - + ucs_for_each_bit(tl, XCCL_TL_ALL) { ctx->tls[tl].enabled = 1; } @@ -154,6 +154,7 @@ xccl_status_t xccl_hier_create_context(xccl_team_lib_t *lib, /* Disable recursion */ ctx->tls[ucs_ilog2(XCCL_TL_HIER)].enabled = 0; ctx->tls[ucs_ilog2(XCCL_TL_MRAIL)].enabled = 0; + ctx->tls[ucs_ilog2(XCCL_TL_DPU)].enabled = 0; ctx->tls[ucs_ilog2(XCCL_TL_SHARP)].enabled = hier_cfg->enable_sharp; ctx->tls[ucs_ilog2(XCCL_TL_SHMSEG)].enabled = hier_cfg->enable_shmseg; ctx->tls[ucs_ilog2(XCCL_TL_HMC)].enabled = hier_cfg->enable_hmc; From 2e158cb01ff3c0b9e83e7c06df3cd9cdece37a37 Mon Sep 17 00:00:00 2001 From: Tomislavj Janjusic Date: Thu, 7 Jan 2021 11:42:23 -0800 Subject: [PATCH 2/2] Adding arch detection in xccl_ucs Signed-off-by: Tomislavj Janjusic Co-authored-by: Artem Polyakov Co-authored-by: Sergey Lebedev --- src/core/xccl_ucs.h | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/core/xccl_ucs.h b/src/core/xccl_ucs.h index 224f793..c911cb8 100644 --- a/src/core/xccl_ucs.h +++ b/src/core/xccl_ucs.h @@ -8,7 +8,16 @@ #define XCCL_UCS_H_ #include -#include + +#if defined(__x86_64__) +# include +#elif defined(__aarch64__) +# include +#elif defined(__powerpc64__) +# include +#else +# error "Unsupported architecture" +#endif #define ucs_ilog2(_n) \ ( \