diff --git a/examples/ucp_client_server_multi_dev.c b/examples/ucp_client_server_multi_dev.c index b69e07f9295..9bb463e8da3 100644 --- a/examples/ucp_client_server_multi_dev.c +++ b/examples/ucp_client_server_multi_dev.c @@ -74,6 +74,7 @@ typedef struct ucx_server_ctx { volatile ucp_conn_request_h conn_request; ucp_listener_h listener; uint64_t client_id; + int dev_id; } ucx_server_ctx_t; @@ -217,7 +218,8 @@ static void err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) /** * Set an address for the server to listen on - INADDR_ANY on a well known port. */ -void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr) +void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr, + int dev_id) { struct sockaddr_in *sa_in; struct sockaddr_in6 *sa_in6; @@ -234,7 +236,7 @@ void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr) sa_in->sin_addr.s_addr = INADDR_ANY; } sa_in->sin_family = AF_INET; - sa_in->sin_port = htons(server_port); + sa_in->sin_port = htons(server_port + dev_id); break; case AF_INET6: sa_in6 = (struct sockaddr_in6*)saddr; @@ -244,7 +246,7 @@ void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr) sa_in6->sin6_addr = in6addr_any; } sa_in6->sin6_family = AF_INET6; - sa_in6->sin6_port = htons(server_port); + sa_in6->sin6_port = htons(server_port + dev_id); break; default: fprintf(stderr, "Invalid address family"); @@ -276,15 +278,12 @@ static int client_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count, { /* Client must send one separate connection request to the server for each * clientGPU-ServerGPU pair of workers. The requests must be sent in the - * order of the client-server GPU id pairs (see server_create_eps). - * Note that all the requests are sent to one signle listener on the - * server. That is, we don't need one listener per worker. */ + * order of the client GPU ids so the server knows which clinet GPU the incoming + * connection request is for (see server_create_eps). */ ucp_ep_params_t ep_params; struct sockaddr_storage connect_addr; ucs_status_t status; - set_sock_addr(address_str, &connect_addr); - /* * Endpoint field mask bits: * UCP_EP_PARAM_FIELD_FLAGS - Use the value of the 'flags' field. @@ -308,15 +307,18 @@ static int client_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count, ep_params.err_handler.arg = NULL; ep_params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER | UCP_EP_PARAMS_FLAGS_SEND_CLIENT_ID; - ep_params.sockaddr.addr = (struct sockaddr*)&connect_addr; - ep_params.sockaddr.addrlen = sizeof(connect_addr); for (int cdev = 0; cdev < dev_count; cdev++) { for (int sdev = 0; sdev < dev_count; sdev++) { + // Set address with the correct port for each server device + set_sock_addr(address_str, &connect_addr, sdev); + ep_params.sockaddr.addr = (struct sockaddr*)&connect_addr; + ep_params.sockaddr.addrlen = sizeof(connect_addr); status = ucp_ep_create(dev_ucp_contexts[cdev].ucp_worker, &ep_params, &dev_ucp_contexts[cdev].ucp_eps[sdev]); if (status != UCS_OK) { - fprintf(stderr, "failed to connect to %s (%s)\n", address_str, + fprintf(stderr, "failed to connect to %s port %d (%s)\n", + address_str, server_port + sdev, ucs_status_string(status)); close_eps(dev_ucp_contexts, dev_count); return -1; @@ -905,7 +907,8 @@ static ucs_status_t server_create_ep(ucp_worker_h ucp_worker, } static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, - int dev_count, ucx_server_ctx_t *context) + ucx_server_ctx_t *server_contexts, + int dev_count) { /* Creating server-side eps. The eps are created upon receiving connection * requests initiated by the client. The client must initiate one request @@ -916,15 +919,12 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, * For each connection request, we need to know: * 1. the client-side UCP worker GPU id associated with the request, * 2. the server-side UCP worker GPU id that the request wants to target - * We rely on a contract between the client and server: the client - * issues the requests in the order of the client-server GPU id pairs. - * That is, the first request is for client_gpu_0 to server_gpu_0, - * the second for client_gpu_0 to server_gpu_1, and so on. Thus, we can use - * a pair of dev_id counters on the server side to map each request to its - * corresponding client-server GPU ids pair. - * Note that we assume the client and server use the same number of GPUs. - * Otherwise, they need to exchange an initial message to let each other - * know about the number of GPUs they use. */ + * For 1, we rely on the fact that the client sends its requests in the + * order of its GPU ids. For 2, we already know it because we have a + * separate listener per server GPU. Note that we assume the client and + * server use the same number of GPUs. Otherwise, they need to exchange an + * initial message to let each other know about the number of GPUs they use. + */ ucs_status_t status; for (int cdev = 0; cdev < dev_count; cdev++) { /* server GPUs */ for (int sdev = 0; sdev < dev_count; sdev++) { /* client GPUs */ @@ -933,12 +933,12 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, * which the server's connection request callback is invoked, * i.e. several clients are trying to connect in parallel, the * server will handle only the first one and reject the rest. */ - while (context->conn_request == NULL) { - ucp_worker_progress(dev_ucp_contexts[0].ucp_worker); + while (server_contexts[sdev].conn_request == NULL) { + ucp_worker_progress(dev_ucp_contexts[sdev].ucp_worker); } status = server_create_ep(dev_ucp_contexts[sdev].ucp_worker, - context->conn_request, + server_contexts[sdev].conn_request, &dev_ucp_contexts[sdev].ucp_eps[cdev]); if (status != UCS_OK) { close_eps(dev_ucp_contexts, dev_count); @@ -950,8 +950,9 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, dev_ucp_contexts[sdev].ep_count++; /* Now we are ready to accept the next request, but only - * for the rest of the GPUs from the same client. */ - context->conn_request = NULL; + * for the rest of the GPUs from the same client. So, do not + * reset client_id. */ + server_contexts[sdev].conn_request = NULL; } } @@ -963,7 +964,7 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, */ static ucs_status_t start_server(ucp_worker_h ucp_worker, ucx_server_ctx_t *context, - ucp_listener_h *listener_p, const char *address_str) + const char *address_str, int dev_id) { struct sockaddr_storage listen_addr; ucp_listener_params_t params; @@ -972,37 +973,39 @@ start_server(ucp_worker_h ucp_worker, ucx_server_ctx_t *context, char ip_str[IP_STRING_LEN]; char port_str[PORT_STRING_LEN]; - set_sock_addr(address_str, &listen_addr); + set_sock_addr(address_str, &listen_addr, dev_id); - params.field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | - UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; - params.sockaddr.addr = (const struct sockaddr*)&listen_addr; - params.sockaddr.addrlen = sizeof(listen_addr); - params.conn_handler.cb = server_conn_handle_cb; - params.conn_handler.arg = context; + params.field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | + UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; + params.sockaddr.addr = (const struct sockaddr*)&listen_addr; + params.sockaddr.addrlen = sizeof(listen_addr); + params.conn_handler.cb = server_conn_handle_cb; + params.conn_handler.arg = context; /* Create a listener on the server side to listen on the given address.*/ - status = ucp_listener_create(ucp_worker, ¶ms, listener_p); + status = ucp_listener_create(ucp_worker, ¶ms, &context->listener); if (status != UCS_OK) { - fprintf(stderr, "failed to listen (%s)\n", ucs_status_string(status)); + fprintf(stderr, "failed to listen for device %d (%s)\n", + dev_id, ucs_status_string(status)); goto out; } /* Query the created listener to get the port it is listening on. */ attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR; - status = ucp_listener_query(*listener_p, &attr); + status = ucp_listener_query(context->listener, &attr); if (status != UCS_OK) { - fprintf(stderr, "failed to query the listener (%s)\n", - ucs_status_string(status)); - ucp_listener_destroy(*listener_p); + fprintf(stderr, "failed to query the listener for device %d (%s)\n", + dev_id, ucs_status_string(status)); + ucp_listener_destroy(context->listener); goto out; } - fprintf(stderr, "server is listening on IP %s port %s\n", + fprintf(stderr, "server is listening on IP %s port %s for device %d\n", sockaddr_get_ip_str(&attr.sockaddr, ip_str, IP_STRING_LEN), - sockaddr_get_port_str(&attr.sockaddr, port_str, PORT_STRING_LEN)); + sockaddr_get_port_str(&attr.sockaddr, port_str, PORT_STRING_LEN), + dev_id); - printf("Waiting for connection...\n"); + printf("Waiting for connection for device %d...\n", dev_id); out: return status; @@ -1033,10 +1036,10 @@ static int client_server_do_work(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count, for (int cdev = 0; cdev < dev_count; cdev++) { for (int sdev = 0; sdev < dev_count; sdev++) { - int ldev = is_server ? sdev : cdev; - int rdev = is_server ? cdev : sdev; + int ldev = is_server ? sdev : cdev; + int rdev = is_server ? cdev : sdev; ucp_worker = dev_ucp_contexts[ldev].ucp_worker; - ucp_ep = dev_ucp_contexts[ldev].ucp_eps[rdev]; + ucp_ep = dev_ucp_contexts[ldev].ucp_eps[rdev]; /* Push the right GPU context */ cudaSetDevice(ldev); cudaFree(0); @@ -1081,30 +1084,36 @@ static int client_server_do_work(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count, static int run_server(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count, char *listen_addr, send_recv_type_t send_recv_type) { - ucx_server_ctx_t context; + ucx_server_ctx_t server_contexts[dev_count]; ucs_status_t status; int ret; - /* Initialize the server's context. */ - context.conn_request = NULL; - context.client_id = 0; /* Create a listener for connection establishment between client and server. * This listener will stay open for listening to incoming connection * requests from the client. - * The listener is created on a worker. We create only one listener on one of - * the workers, and will use it for processing the incoming connection requests - * from all other workers (that correspond to multiple GPUs). */ - status = start_server(dev_ucp_contexts[0].ucp_worker, &context, - &context.listener, listen_addr); - if (status != UCS_OK) { - ret = -1; - goto err; + * The listener is created on a worker. We create one listener for each + * device-specific context/worker, and will use it for processing the + * incoming connection requests from all other workers + * (that correspond to multiple GPUs). */ + + for (int dev_id = 0; dev_id < dev_count; dev_id++) { + server_contexts[dev_id].conn_request = NULL; + server_contexts[dev_id].client_id = 0; + server_contexts[dev_id].dev_id = dev_id; + + status = start_server(dev_ucp_contexts[dev_id].ucp_worker, + &server_contexts[dev_id], + listen_addr, dev_id); + if (status != UCS_OK) { + ret = -1; + goto err_listener; + } } - /* Server is always up listening */ + /* Servers are always up listening */ while (1) { - ret = server_create_eps(dev_ucp_contexts, dev_count, &context); + ret = server_create_eps(dev_ucp_contexts, server_contexts, dev_count); if (ret != 0) { goto err_listener; @@ -1121,18 +1130,23 @@ static int run_server(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count, /* Close all the endpoints to the client */ close_eps(dev_ucp_contexts, dev_count); - /* Reinitialize the server's context to be used for the next client */ - context.conn_request = NULL; - context.client_id = 0; + /* Reinitialize the server contexts to be used for the next client */ + for (int dev_id = 0; dev_id < dev_count; dev_id++) { + server_contexts[dev_id].conn_request = NULL; + server_contexts[dev_id].client_id = 0; + } - printf("Waiting for connection...\n"); + printf("Waiting for next client connections...\n"); } err_ep: close_eps(dev_ucp_contexts, dev_count); err_listener: - ucp_listener_destroy(context.listener); -err: + for (int dev_id = 0; dev_id < dev_count; dev_id++) { + if (server_contexts[dev_id].listener != NULL) { + ucp_listener_destroy(server_contexts[dev_id].listener); + } + } return ret; }