Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add capability to register a user callback for receiving AM messages #785

Draft
wants to merge 6 commits into
base: branch-0.22
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/test_send_recv_am.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,50 @@ async def test_send_recv_bytes(size, blocking_progress_mode, recv_wait, data):
assert recv[0] == bytearray(msg.get())
else:
data["validator"](recv[0], msg)


@pytest.mark.skipif(
not ucp._libs.ucx_api.is_am_supported(), reason="AM only supported in UCX >= 1.11"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("size", msg_sizes)
@pytest.mark.parametrize("blocking_progress_mode", [True, False])
@pytest.mark.parametrize("recv_wait", [True, False])
@pytest.mark.parametrize("data", get_data())
async def test_send_recv_bytes_callback(size, blocking_progress_mode, recv_wait, data):
rndv_thresh = 8192
ucp.init(
options={"RNDV_THRESH": str(rndv_thresh)},
blocking_progress_mode=blocking_progress_mode,
)

recv = []

async def _cb(recv_obj, exception, ep):
recv.append(recv_obj)

ucp.register_am_allocator(data["allocator"], data["memory_type"])
ucp.register_am_recv_callback(_cb)
msg = data["generator"](size)

num_clients = 1
clients = [
await ucp.create_endpoint_from_worker_address(ucp.get_worker_address())
for i in range(num_clients)
]
for c in clients:
if recv_wait:
# By sleeping here we ensure that the listener's
# ep.am_recv call will have to wait, rather than return
# immediately as receive data is already available.
await asyncio.sleep(1)
await c.am_send(msg)
for c in clients:
await c.close()

if data["memory_type"] == "cuda" and msg.nbytes < rndv_thresh:
# Eager messages are always received on the host, if no host
# allocator is registered UCX-Py defaults to `bytearray`.
assert recv[0] == bytearray(msg.get())
else:
data["validator"](recv[0], msg)
79 changes: 54 additions & 25 deletions ucp/_libs/tests/test_server_client_am.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import multiprocessing as mp
import os
import pickle
from functools import partial
from queue import Empty as QueueIsEmpty

Expand Down Expand Up @@ -49,7 +50,9 @@ def get_data():
return ret


def _echo_server(get_queue, put_queue, msg_size, datatype, endpoint_error_handling):
def _echo_server(
get_queue, put_queue, msg_size, datatype, user_callback, endpoint_error_handling
):
"""Server that send received message back to the client

Notice, since it is illegal to call progress() in call-back functions,
Expand All @@ -64,10 +67,6 @@ def _echo_server(get_queue, put_queue, msg_size, datatype, endpoint_error_handli
worker = ucx_api.UCXWorker(ctx)
worker.register_am_allocator(data["allocator"], data["memory_type"])

# A reference to listener's endpoint is stored to prevent it from going
# out of scope too early.
ep = None

def _send_handle(request, exception, msg):
# Notice, we pass `msg` to the handler in order to make sure
# it doesn't go out of scope prematurely.
Expand All @@ -78,20 +77,28 @@ def _recv_handle(recv_obj, exception, ep):
msg = Array(recv_obj)
ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,))

def _listener_handler(conn_request):
global ep
ep = ucx_api.UCXEndpoint.create_from_conn_request(
worker, conn_request, endpoint_error_handling=endpoint_error_handling,
)
if user_callback is True:
worker.register_am_recv_callback(_recv_handle)
put_queue.put(pickle.dumps(worker.get_address()))
else:
# A reference to listener's endpoint is stored to prevent it from going
# out of scope too early.
ep = None

# Wireup
ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))
def _listener_handler(conn_request):
global ep
ep = ucx_api.UCXEndpoint.create_from_conn_request(
worker, conn_request, endpoint_error_handling=endpoint_error_handling,
)

# Data
ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))
# Wireup
ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
put_queue.put(listener.port)
# Data
ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,))

listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
put_queue.put(listener.port)

while True:
worker.progress()
Expand All @@ -103,7 +110,9 @@ def _listener_handler(conn_request):
break


def _echo_client(msg_size, datatype, port, endpoint_error_handling):
def _echo_client(
msg_size, datatype, user_callback, server_info, endpoint_error_handling
):
data = get_data()[datatype]

ctx = ucx_api.UCXContext(
Expand All @@ -113,9 +122,16 @@ def _echo_client(msg_size, datatype, port, endpoint_error_handling):
worker = ucx_api.UCXWorker(ctx)
worker.register_am_allocator(data["allocator"], data["memory_type"])

ep = ucx_api.UCXEndpoint.create(
worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
)
if user_callback is True:
server_worker_addr = pickle.loads(server_info)
ep = ucx_api.UCXEndpoint.create_from_worker_address(
worker, server_worker_addr, endpoint_error_handling=endpoint_error_handling,
)
else:
port = server_info
ep = ucx_api.UCXEndpoint.create(
worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
)

# The wireup message is sent to ensure endpoints are connected, otherwise
# UCX may not perform any rendezvous transfers.
Expand All @@ -134,10 +150,14 @@ def _echo_client(msg_size, datatype, port, endpoint_error_handling):
recv_wireup = bytearray(recv_wireup)
assert bytearray(recv_wireup) == send_wireup

if data["memory_type"] == "cuda" and send_data.nbytes < RNDV_THRESH:
if (
data["memory_type"] == ucx_api.AllocatorType.CUDA
and send_data.nbytes < RNDV_THRESH
):
# Eager messages are always received on the host, if no host
# allocator is registered UCX-Py defaults to `bytearray`.
assert recv_data == bytearray(send_data.get())
else:
data["validator"](recv_data, send_data)


Expand All @@ -146,18 +166,27 @@ def _echo_client(msg_size, datatype, port, endpoint_error_handling):
)
@pytest.mark.parametrize("msg_size", [10, 2 ** 24])
@pytest.mark.parametrize("datatype", get_data().keys())
def test_server_client(msg_size, datatype):
@pytest.mark.parametrize("user_callback", [False, True])
def test_server_client(msg_size, datatype, user_callback):
endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)

put_queue, get_queue = mp.Queue(), mp.Queue()
server = mp.Process(
target=_echo_server,
args=(put_queue, get_queue, msg_size, datatype, endpoint_error_handling),
args=(
put_queue,
get_queue,
msg_size,
datatype,
user_callback,
endpoint_error_handling,
),
)
server.start()
port = get_queue.get()
server_info = get_queue.get()
client = mp.Process(
target=_echo_client, args=(msg_size, datatype, port, endpoint_error_handling)
target=_echo_client,
args=(msg_size, datatype, user_callback, server_info, endpoint_error_handling),
)
client.start()
client.join(timeout=10)
Expand Down
4 changes: 4 additions & 0 deletions ucp/_libs/transfer_am.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ def am_recv_nb(
IF CY_UCP_AM_SUPPORTED:
worker = ep.worker

if worker.is_am_recv_callback_registered():
raise RuntimeError("`am_recv_nb` cannot be used when a callback was "
"registered to worker with `register_am_recv_callback`")

if cb_args is None:
cb_args = ()
if cb_kwargs is None:
Expand Down
107 changes: 74 additions & 33 deletions ucp/_libs/ucx_endpoint.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -135,45 +135,65 @@ cdef class UCXEndpoint(UCXObject):
def __init__(
self,
UCXWorker worker,
uintptr_t params_as_int,
bint endpoint_error_handling
uintptr_t params_as_int=0,
endpoint_error_handling=None,
uintptr_t ep_as_int=0,
):
"""The Constructor"""

cdef ucp_err_handler_cb_t err_cb
cdef uintptr_t ep_status
cdef ucp_ep_params_t *params = NULL
cdef ucp_ep_h ucp_ep
cdef ucs_status_t status

assert worker.initialized
self.worker = worker
self._inflight_msgs = set()
self._endpoint_error_handling = endpoint_error_handling

cdef ucp_err_handler_cb_t err_cb
cdef uintptr_t ep_status
err_cb, ep_status = (
_get_error_callback(worker._context._config["TLS"], endpoint_error_handling)
)
if params_as_int == 0 and ep_as_int == 0:
raise ValueError("At least one of `params_as_int` or `ep` must be set.")
elif params_as_int != 0 and ep_as_int != 0:
raise ValueError("`params_as_int` and `ep` are mutually exclusive.")
elif params_as_int != 0:
err_cb, ep_status = (
_get_error_callback(
worker._context._config["TLS"],
endpoint_error_handling
)
)

cdef ucp_ep_params_t *params = <ucp_ep_params_t *>params_as_int
if err_cb == NULL:
params.err_mode = UCP_ERR_HANDLING_MODE_NONE
params = <ucp_ep_params_t *>params_as_int
if err_cb == NULL:
params.err_mode = UCP_ERR_HANDLING_MODE_NONE
else:
params.err_mode = UCP_ERR_HANDLING_MODE_PEER
params.err_handler.cb = err_cb
params.err_handler.arg = <void *>self

status = ucp_ep_create(worker._handle, params, &ucp_ep)
assert_ucs_status(status)

self._handle = ucp_ep
self._status = <uintptr_t>ep_status
self.add_handle_finalizer(
_ucx_endpoint_finalizer,
int(<uintptr_t>ucp_ep),
int(<uintptr_t>ep_status),
endpoint_error_handling,
worker,
self._inflight_msgs,
)
worker.add_child(self)
else:
params.err_mode = UCP_ERR_HANDLING_MODE_PEER
params.err_handler.cb = err_cb
params.err_handler.arg = <void *>self

cdef ucp_ep_h ucp_ep
cdef ucs_status_t status = ucp_ep_create(worker._handle, params, &ucp_ep)
assert_ucs_status(status)

self._handle = ucp_ep
self._status = <uintptr_t>ep_status
self._endpoint_error_handling = endpoint_error_handling
self.add_handle_finalizer(
_ucx_endpoint_finalizer,
int(<uintptr_t>ucp_ep),
int(<uintptr_t>ep_status),
endpoint_error_handling,
worker,
self._inflight_msgs,
)
worker.add_child(self)
self._handle = <ucp_ep_h>ep_as_int
self._status = <uintptr_t>UCS_OK
self.add_handle_finalizer(
lambda handle: None,
int(<uintptr_t>ucp_ep),
)
worker.add_child(self)

@classmethod
def create(
Expand All @@ -200,7 +220,11 @@ cdef class UCXEndpoint(UCXObject):
raise MemoryError("Failed allocation of sockaddr")

try:
return cls(worker, <uintptr_t>params, endpoint_error_handling)
return cls(
worker,
params_as_int=<uintptr_t>params,
endpoint_error_handling=endpoint_error_handling
)
finally:
c_util_sockaddr_free(&params.sockaddr)
free(<void *>params)
Expand All @@ -221,7 +245,11 @@ cdef class UCXEndpoint(UCXObject):
params.address = address._address

try:
return cls(worker, <uintptr_t>params, endpoint_error_handling)
return cls(
worker,
params_as_int=<uintptr_t>params,
endpoint_error_handling=endpoint_error_handling
)
finally:
free(<void *>params)

Expand All @@ -243,10 +271,23 @@ cdef class UCXEndpoint(UCXObject):
params.conn_request = <ucp_conn_request_h> conn_request

try:
return cls(worker, <uintptr_t>params, endpoint_error_handling)
return cls(
worker,
params_as_int=<uintptr_t>params,
endpoint_error_handling=endpoint_error_handling
)
finally:
free(<void *>params)

@classmethod
def create_from_handle(
cls,
UCXWorker worker,
uintptr_t ep,
):
assert worker.initialized
return cls(worker, ep_as_int=ep)

def info(self):
assert self.initialized

Expand Down
Loading