Skip to content

Commit

Permalink
Use a separate listener for each device
Browse files Browse the repository at this point in the history
  • Loading branch information
SeyedMir committed Feb 14, 2025
1 parent 2abc753 commit 00589c7
Showing 1 changed file with 80 additions and 66 deletions.
146 changes: 80 additions & 66 deletions examples/ucp_client_server_multi_dev.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ typedef struct ucx_server_ctx {
volatile ucp_conn_request_h conn_request;
ucp_listener_h listener;
uint64_t client_id;
int dev_id;
} ucx_server_ctx_t;


Expand Down Expand Up @@ -217,7 +218,8 @@ static void err_cb(void *arg, ucp_ep_h ep, ucs_status_t status)
/**
* Set an address for the server to listen on - INADDR_ANY on a well known port.
*/
void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr)
void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr,
int dev_id)
{
struct sockaddr_in *sa_in;
struct sockaddr_in6 *sa_in6;
Expand All @@ -234,7 +236,7 @@ void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr)
sa_in->sin_addr.s_addr = INADDR_ANY;
}
sa_in->sin_family = AF_INET;
sa_in->sin_port = htons(server_port);
sa_in->sin_port = htons(server_port + dev_id);
break;
case AF_INET6:
sa_in6 = (struct sockaddr_in6*)saddr;
Expand All @@ -244,7 +246,7 @@ void set_sock_addr(const char *address_str, struct sockaddr_storage *saddr)
sa_in6->sin6_addr = in6addr_any;
}
sa_in6->sin6_family = AF_INET6;
sa_in6->sin6_port = htons(server_port);
sa_in6->sin6_port = htons(server_port + dev_id);
break;
default:
fprintf(stderr, "Invalid address family");
Expand Down Expand Up @@ -276,15 +278,12 @@ static int client_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
{
/* Client must send one separate connection request to the server for each
* clientGPU-ServerGPU pair of workers. The requests must be sent in the
* order of the client-server GPU id pairs (see server_create_eps).
* Note that all the requests are sent to one signle listener on the
* server. That is, we don't need one listener per worker. */
* order of the client GPU ids so the server knows which clinet GPU the incoming
* connection request is for (see server_create_eps). */
ucp_ep_params_t ep_params;
struct sockaddr_storage connect_addr;
ucs_status_t status;

set_sock_addr(address_str, &connect_addr);

/*
* Endpoint field mask bits:
* UCP_EP_PARAM_FIELD_FLAGS - Use the value of the 'flags' field.
Expand All @@ -308,15 +307,18 @@ static int client_create_eps(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
ep_params.err_handler.arg = NULL;
ep_params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER |
UCP_EP_PARAMS_FLAGS_SEND_CLIENT_ID;
ep_params.sockaddr.addr = (struct sockaddr*)&connect_addr;
ep_params.sockaddr.addrlen = sizeof(connect_addr);

for (int cdev = 0; cdev < dev_count; cdev++) {
for (int sdev = 0; sdev < dev_count; sdev++) {
// Set address with the correct port for each server device
set_sock_addr(address_str, &connect_addr, sdev);
ep_params.sockaddr.addr = (struct sockaddr*)&connect_addr;
ep_params.sockaddr.addrlen = sizeof(connect_addr);
status = ucp_ep_create(dev_ucp_contexts[cdev].ucp_worker, &ep_params,
&dev_ucp_contexts[cdev].ucp_eps[sdev]);
if (status != UCS_OK) {
fprintf(stderr, "failed to connect to %s (%s)\n", address_str,
fprintf(stderr, "failed to connect to %s port %d (%s)\n",
address_str, server_port + sdev,
ucs_status_string(status));
close_eps(dev_ucp_contexts, dev_count);
return -1;
Expand Down Expand Up @@ -905,7 +907,8 @@ static ucs_status_t server_create_ep(ucp_worker_h ucp_worker,
}

static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts,
int dev_count, ucx_server_ctx_t *context)
ucx_server_ctx_t *server_contexts,
int dev_count)
{
/* Creating server-side eps. The eps are created upon receiving connection
* requests initiated by the client. The client must initiate one request
Expand All @@ -916,15 +919,12 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts,
* For each connection request, we need to know:
* 1. the client-side UCP worker GPU id associated with the request,
* 2. the server-side UCP worker GPU id that the request wants to target
* We rely on a contract between the client and server: the client
* issues the requests in the order of the client-server GPU id pairs.
* That is, the first request is for client_gpu_0 to server_gpu_0,
* the second for client_gpu_0 to server_gpu_1, and so on. Thus, we can use
* a pair of dev_id counters on the server side to map each request to its
* corresponding client-server GPU ids pair.
* Note that we assume the client and server use the same number of GPUs.
* Otherwise, they need to exchange an initial message to let each other
* know about the number of GPUs they use. */
* For 1, we rely on the fact that the client sends its requests in the
* order of its GPU ids. For 2, we already know it because we have a
* separate listener per server GPU. Note that we assume the client and
* server use the same number of GPUs. Otherwise, they need to exchange an
* initial message to let each other know about the number of GPUs they use.
*/
ucs_status_t status;
for (int cdev = 0; cdev < dev_count; cdev++) { /* server GPUs */
for (int sdev = 0; sdev < dev_count; sdev++) { /* client GPUs */
Expand All @@ -933,12 +933,12 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts,
* which the server's connection request callback is invoked,
* i.e. several clients are trying to connect in parallel, the
* server will handle only the first one and reject the rest. */
while (context->conn_request == NULL) {
ucp_worker_progress(dev_ucp_contexts[0].ucp_worker);
while (server_contexts[sdev].conn_request == NULL) {
ucp_worker_progress(dev_ucp_contexts[sdev].ucp_worker);
}

status = server_create_ep(dev_ucp_contexts[sdev].ucp_worker,
context->conn_request,
server_contexts[sdev].conn_request,
&dev_ucp_contexts[sdev].ucp_eps[cdev]);
if (status != UCS_OK) {
close_eps(dev_ucp_contexts, dev_count);
Expand All @@ -950,8 +950,9 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts,
dev_ucp_contexts[sdev].ep_count++;

/* Now we are ready to accept the next request, but only
* for the rest of the GPUs from the same client. */
context->conn_request = NULL;
* for the rest of the GPUs from the same client. So, do not
* reset client_id. */
server_contexts[sdev].conn_request = NULL;
}
}

Expand All @@ -963,7 +964,7 @@ static ucs_status_t server_create_eps(dev_ucp_ctx_t *dev_ucp_contexts,
*/
static ucs_status_t
start_server(ucp_worker_h ucp_worker, ucx_server_ctx_t *context,
ucp_listener_h *listener_p, const char *address_str)
const char *address_str, int dev_id)
{
struct sockaddr_storage listen_addr;
ucp_listener_params_t params;
Expand All @@ -972,37 +973,39 @@ start_server(ucp_worker_h ucp_worker, ucx_server_ctx_t *context,
char ip_str[IP_STRING_LEN];
char port_str[PORT_STRING_LEN];

set_sock_addr(address_str, &listen_addr);
set_sock_addr(address_str, &listen_addr, dev_id);

params.field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR |
UCP_LISTENER_PARAM_FIELD_CONN_HANDLER;
params.sockaddr.addr = (const struct sockaddr*)&listen_addr;
params.sockaddr.addrlen = sizeof(listen_addr);
params.conn_handler.cb = server_conn_handle_cb;
params.conn_handler.arg = context;
params.field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR |
UCP_LISTENER_PARAM_FIELD_CONN_HANDLER;
params.sockaddr.addr = (const struct sockaddr*)&listen_addr;
params.sockaddr.addrlen = sizeof(listen_addr);
params.conn_handler.cb = server_conn_handle_cb;
params.conn_handler.arg = context;

/* Create a listener on the server side to listen on the given address.*/
status = ucp_listener_create(ucp_worker, &params, listener_p);
status = ucp_listener_create(ucp_worker, &params, &context->listener);
if (status != UCS_OK) {
fprintf(stderr, "failed to listen (%s)\n", ucs_status_string(status));
fprintf(stderr, "failed to listen for device %d (%s)\n",
dev_id, ucs_status_string(status));
goto out;
}

/* Query the created listener to get the port it is listening on. */
attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR;
status = ucp_listener_query(*listener_p, &attr);
status = ucp_listener_query(context->listener, &attr);
if (status != UCS_OK) {
fprintf(stderr, "failed to query the listener (%s)\n",
ucs_status_string(status));
ucp_listener_destroy(*listener_p);
fprintf(stderr, "failed to query the listener for device %d (%s)\n",
dev_id, ucs_status_string(status));
ucp_listener_destroy(context->listener);
goto out;
}

fprintf(stderr, "server is listening on IP %s port %s\n",
fprintf(stderr, "server is listening on IP %s port %s for device %d\n",
sockaddr_get_ip_str(&attr.sockaddr, ip_str, IP_STRING_LEN),
sockaddr_get_port_str(&attr.sockaddr, port_str, PORT_STRING_LEN));
sockaddr_get_port_str(&attr.sockaddr, port_str, PORT_STRING_LEN),
dev_id);

printf("Waiting for connection...\n");
printf("Waiting for connection for device %d...\n", dev_id);

out:
return status;
Expand Down Expand Up @@ -1033,10 +1036,10 @@ static int client_server_do_work(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,

for (int cdev = 0; cdev < dev_count; cdev++) {
for (int sdev = 0; sdev < dev_count; sdev++) {
int ldev = is_server ? sdev : cdev;
int rdev = is_server ? cdev : sdev;
int ldev = is_server ? sdev : cdev;
int rdev = is_server ? cdev : sdev;
ucp_worker = dev_ucp_contexts[ldev].ucp_worker;
ucp_ep = dev_ucp_contexts[ldev].ucp_eps[rdev];
ucp_ep = dev_ucp_contexts[ldev].ucp_eps[rdev];
/* Push the right GPU context */
cudaSetDevice(ldev);
cudaFree(0);
Expand Down Expand Up @@ -1081,30 +1084,36 @@ static int client_server_do_work(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
static int run_server(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
char *listen_addr, send_recv_type_t send_recv_type)
{
ucx_server_ctx_t context;
ucx_server_ctx_t server_contexts[dev_count];
ucs_status_t status;
int ret;

/* Initialize the server's context. */
context.conn_request = NULL;
context.client_id = 0;

/* Create a listener for connection establishment between client and server.
* This listener will stay open for listening to incoming connection
* requests from the client.
* The listener is created on a worker. We create only one listener on one of
* the workers, and will use it for processing the incoming connection requests
* from all other workers (that correspond to multiple GPUs). */
status = start_server(dev_ucp_contexts[0].ucp_worker, &context,
&context.listener, listen_addr);
if (status != UCS_OK) {
ret = -1;
goto err;
* The listener is created on a worker. We create one listener for each
* device-specific context/worker, and will use it for processing the
* incoming connection requests from all other workers
* (that correspond to multiple GPUs). */

for (int dev_id = 0; dev_id < dev_count; dev_id++) {
server_contexts[dev_id].conn_request = NULL;
server_contexts[dev_id].client_id = 0;
server_contexts[dev_id].dev_id = dev_id;

status = start_server(dev_ucp_contexts[dev_id].ucp_worker,
&server_contexts[dev_id],
listen_addr, dev_id);
if (status != UCS_OK) {
ret = -1;
goto err_listener;
}
}

/* Server is always up listening */
/* Servers are always up listening */
while (1) {
ret = server_create_eps(dev_ucp_contexts, dev_count, &context);
ret = server_create_eps(dev_ucp_contexts, server_contexts, dev_count);

if (ret != 0) {
goto err_listener;
Expand All @@ -1121,18 +1130,23 @@ static int run_server(dev_ucp_ctx_t *dev_ucp_contexts, int dev_count,
/* Close all the endpoints to the client */
close_eps(dev_ucp_contexts, dev_count);

/* Reinitialize the server's context to be used for the next client */
context.conn_request = NULL;
context.client_id = 0;
/* Reinitialize the server contexts to be used for the next client */
for (int dev_id = 0; dev_id < dev_count; dev_id++) {
server_contexts[dev_id].conn_request = NULL;
server_contexts[dev_id].client_id = 0;
}

printf("Waiting for connection...\n");
printf("Waiting for next client connections...\n");
}

err_ep:
close_eps(dev_ucp_contexts, dev_count);
err_listener:
ucp_listener_destroy(context.listener);
err:
for (int dev_id = 0; dev_id < dev_count; dev_id++) {
if (server_contexts[dev_id].listener != NULL) {
ucp_listener_destroy(server_contexts[dev_id].listener);
}
}
return ret;
}

Expand Down

0 comments on commit 00589c7

Please sign in to comment.