Skip to content

Commit

Permalink
fix: Stabilize ZMQ RPC auth and malformed packet handling (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Feb 28, 2024
1 parent d0461d0 commit 1bf72f5
Showing 1 changed file with 43 additions and 33 deletions.
76 changes: 43 additions & 33 deletions src/callosum/lower/zeromq.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,22 +225,38 @@ async def recv_message(self) -> AsyncGenerator[Optional[RawHeaderBody], None]:
assert self.transport._sock is not None
while True:
multipart_msg = await self.transport._sock.recv_multipart()
*pre, zmsg_type, raw_header, raw_body = multipart_msg
if zmsg_type == b"PING":
await self.transport._sock.send_multipart([
*pre,
b"PONG",
raw_header,
raw_body,
])
elif zmsg_type == b"UPPER":
if len(pre) > 0:
# server
peer_id = pre[0]
yield RawHeaderBody(raw_header, raw_body, peer_id)
try:
*pre, zmsg_type, raw_header, raw_body = multipart_msg
if zmsg_type == b"PING":
await self.transport._sock.send_multipart([
*pre,
b"PONG",
raw_header,
raw_body,
])
elif zmsg_type == b"UPPER":
if len(pre) > 0:
# server
peer_id = pre[0]
yield RawHeaderBody(raw_header, raw_body, peer_id)
else:
# client
yield RawHeaderBody(raw_header, raw_body, None)
else:
# client
yield RawHeaderBody(raw_header, raw_body, None)
# Ignore if the peer has sent a malformed multipart message.
log.debug(
"ZeroMQRPCConnection.recv_message(): "
"ignoring an invalid message from the peer..."
)
except Exception as e:
# ValueError may happen when there are garbage packets accepted by
# the zmq socket.
log.debug(
"ZeroMQRPCConnection.recv_message(): "
"exception caught in the recv loop, continuing...",
exc_info=e,
)
continue

async def send_message(self, raw_msg: RawHeaderBody) -> None:
assert not self.transport._closed
Expand Down Expand Up @@ -316,14 +332,6 @@ def __init__(
self._attach_monitor = attach_monitor
self._zsock_opts = {**_default_zsock_opts, **(zsock_opts or {})}

async def ping(self, ping_timeout: Optional[int] = None) -> bool:
assert self._main_sock is not None
sock: zmq.asyncio.Socket = self._main_sock
await sock.send_multipart([b"PING", b"", b""])
async with asyncio.timeout(ping_timeout / 1000 if ping_timeout else None):
response = await sock.recv_multipart()
return response[0] == b"PONG"

async def __aenter__(self):
if not self.transport._closed:
return ZeroMQRPCConnection(self.transport)
Expand Down Expand Up @@ -387,9 +395,13 @@ async def ping(self, ping_timeout: Optional[int] = None) -> bool:
assert self._main_sock is not None
sock: zmq.asyncio.Socket = self._main_sock
await sock.send_multipart([b"PING", b"", b""])
async with asyncio.timeout(ping_timeout / 1000 if ping_timeout else None):
response = await sock.recv_multipart()
return response[0] == b"PONG"
if await sock.poll(ping_timeout):
try:
response = await sock.recv_multipart(zmq.NOBLOCK)
except zmq.Again:
return False
return response[0] == b"PONG"
return False

async def __aenter__(self) -> ZeroMQRPCConnection:
if not self.transport._closed:
Expand All @@ -409,13 +421,11 @@ async def __aenter__(self) -> ZeroMQRPCConnection:
client_sock.connect(self.addr.uri)
self._main_sock = client_sock
self.transport._sock = client_sock
try:
await self.ping(
ping_timeout=int(self._handshake_timeout * 1000)
if self._handshake_timeout is not None
else 5000
)
except asyncio.TimeoutError:
if not await self.ping(
ping_timeout=int(self._handshake_timeout * 1000)
if self._handshake_timeout is not None
else 5000
):
raise AuthenticationError
handshake_done = time.perf_counter()
log.debug(
Expand Down

0 comments on commit 1bf72f5

Please sign in to comment.