From 642b1456f1d98f7b3fc3223dc4168fc15fbded7d Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Sun, 18 Feb 2024 16:14:22 -0600 Subject: [PATCH] Improve handling of async sockets --- jupyter_client/session.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/jupyter_client/session.py b/jupyter_client/session.py index c387cd06..1caf9634 100644 --- a/jupyter_client/session.py +++ b/jupyter_client/session.py @@ -28,6 +28,7 @@ # We are using compare_digest to limit the surface of timing attacks import zmq.asyncio +from jupyter_core.utils import run_sync from tornado.ioloop import IOLoop from traitlets import ( Any, @@ -812,7 +813,13 @@ def send( if isinstance(stream, zmq.asyncio.Socket): assert stream is not None # type:ignore[unreachable] - stream = zmq.Socket.shadow(stream.underlying) + + async def send_multipart(*args, **kwargs): + return await stream.send_multipart(*args, **kwargs) + + send_func = run_sync(send_multipart) + else: + send_func = stream.send_multipart if isinstance(msg_or_type, (Message, dict)): # We got a Message or message dict, not a msg_type so don't @@ -856,11 +863,11 @@ def send( if stream and buffers and track and not copy: # only really track when we are doing zero-copy buffers - tracker = stream.send_multipart(to_send, copy=False, track=True) + tracker = send_func(to_send, copy=False, track=True) elif stream: # use dummy tracker, which will be done immediately tracker = DONE - stream.send_multipart(to_send, copy=copy) + send_func(to_send, copy=copy) else: tracker = DONE @@ -907,8 +914,15 @@ def send_raw( to_send.append(self.sign(msg_list[0:4])) to_send.extend(msg_list) if isinstance(stream, zmq.asyncio.Socket): - stream = zmq.Socket.shadow(stream.underlying) - stream.send_multipart(to_send, flags, copy=copy) + assert stream is not None # type:ignore[unreachable] + + async def send_multipart(*args, **kwargs): + return await stream.send_multipart(*args, **kwargs) + + send_func = run_sync(send_multipart) + else: + send_func = stream.send_multipart + send_func(to_send, flags, copy=copy) def recv( self, @@ -932,11 +946,18 @@ def recv( """ if isinstance(socket, ZMQStream): # type:ignore[unreachable] socket = socket.socket # type:ignore[unreachable] + if isinstance(socket, zmq.asyncio.Socket): - socket = zmq.Socket.shadow(socket.underlying) + + async def recv_multipart(*args, **kwargs): + return await socket.recv_multipart(*args, **kwargs) + + recv_func = run_sync(recv_multipart) + else: + recv_func = socket.recv_multipart try: - msg_list = socket.recv_multipart(mode, copy=copy) + msg_list = recv_func(mode, copy=copy) except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: # We can convert EAGAIN to None as we know in this case