Skip to content

Commit

Permalink
Improve handling of async sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 committed Feb 18, 2024
1 parent 4ffbfcf commit 642b145
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions jupyter_client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

# We are using compare_digest to limit the surface of timing attacks
import zmq.asyncio
from jupyter_core.utils import run_sync
from tornado.ioloop import IOLoop
from traitlets import (
Any,
Expand Down Expand Up @@ -812,7 +813,13 @@ def send(

if isinstance(stream, zmq.asyncio.Socket):
assert stream is not None # type:ignore[unreachable]
stream = zmq.Socket.shadow(stream.underlying)

async def send_multipart(*args, **kwargs):
return await stream.send_multipart(*args, **kwargs)

send_func = run_sync(send_multipart)
else:
send_func = stream.send_multipart

if isinstance(msg_or_type, (Message, dict)):
# We got a Message or message dict, not a msg_type so don't
Expand Down Expand Up @@ -856,11 +863,11 @@ def send(

if stream and buffers and track and not copy:
# only really track when we are doing zero-copy buffers
tracker = stream.send_multipart(to_send, copy=False, track=True)
tracker = send_func(to_send, copy=False, track=True)
elif stream:
# use dummy tracker, which will be done immediately
tracker = DONE
stream.send_multipart(to_send, copy=copy)
send_func(to_send, copy=copy)
else:
tracker = DONE

Expand Down Expand Up @@ -907,8 +914,15 @@ def send_raw(
to_send.append(self.sign(msg_list[0:4]))
to_send.extend(msg_list)
if isinstance(stream, zmq.asyncio.Socket):
stream = zmq.Socket.shadow(stream.underlying)
stream.send_multipart(to_send, flags, copy=copy)
assert stream is not None # type:ignore[unreachable]

async def send_multipart(*args, **kwargs):
return await stream.send_multipart(*args, **kwargs)

send_func = run_sync(send_multipart)
else:
send_func = stream.send_multipart
send_func(to_send, flags, copy=copy)

def recv(
self,
Expand All @@ -932,11 +946,18 @@ def recv(
"""
if isinstance(socket, ZMQStream): # type:ignore[unreachable]
socket = socket.socket # type:ignore[unreachable]

if isinstance(socket, zmq.asyncio.Socket):
socket = zmq.Socket.shadow(socket.underlying)

async def recv_multipart(*args, **kwargs):
return await socket.recv_multipart(*args, **kwargs)

recv_func = run_sync(recv_multipart)
else:
recv_func = socket.recv_multipart

try:
msg_list = socket.recv_multipart(mode, copy=copy)
msg_list = recv_func(mode, copy=copy)
except zmq.ZMQError as e:
if e.errno == zmq.EAGAIN:
# We can convert EAGAIN to None as we know in this case
Expand Down

0 comments on commit 642b145

Please sign in to comment.