Skip to content

Commit

Permalink
Merge pull request #166 from mosquito/feafure/connection-close-fixes
Browse files Browse the repository at this point in the history
Feafure/connection close fixes
  • Loading branch information
mosquito authored Feb 10, 2023
2 parents 6cceb15 + df7f130 commit d82244c
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 44 deletions.
6 changes: 3 additions & 3 deletions aiormq/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def drain(self) -> None:
if not self.should_drain:
return

if self.drain_future is not None:
if self.drain_future is not None and not self.drain_future.done():
self.drain_future.set_result(None)

@property
Expand All @@ -218,9 +218,9 @@ def marshall(
for frame in frames:
if should_close:
logger = logging.getLogger(
"aiormq.connection"
"aiormq.connection",
).getChild(
"marshall"
"marshall",
)

logger.warning(
Expand Down
50 changes: 31 additions & 19 deletions aiormq/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class Returning(asyncio.Future):
class Channel(Base, AbstractChannel):
# noinspection PyTypeChecker
CONTENT_FRAME_SIZE = len(pamqp.frame.marshal(ContentBody(b""), 0))

CHANNEL_CLOSE_TIMEOUT = 10
confirmations: Dict[int, ConfirmationType]

def __init__(
Expand Down Expand Up @@ -404,11 +404,13 @@ async def _on_close_frame(self, frame: spec.Channel.Close) -> None:
),
)
self.connection.channels.pop(self.number, None)
self.__close_event.set()
raise exc

async def _on_close_ok_frame(self, _: spec.Channel.CloseOk) -> None:
self.connection.channels.pop(self.number, None)
raise ChannelClosed()
self.__close_event.set()
raise ChannelClosed(None, None)

async def _reader(self) -> None:
hooks: Mapping[Any, Tuple[bool, Callable[[Any], Awaitable[None]]]]
Expand All @@ -426,35 +428,45 @@ async def _reader(self) -> None:
spec.Basic.Nack: (False, self._on_confirm_frame),
}

last_exception: Optional[BaseException] = None

try:
while True:
frame = await self._get_frame()
should_add_to_rpc, hook = hooks.get(type(frame), (True, None))

if hook is not None:
try:
await hook(frame)
except asyncio.CancelledError as e:
await self._cancel_tasks(e)
return
await hook(frame)

if should_add_to_rpc:
await self.rpc_frames.put(frame)
except asyncio.CancelledError as e:
self.__close_event.set()
last_exception = e
return
except Exception as e:
await self._cancel_tasks(e)
last_exception = e
raise
finally:
await self.close(
last_exception, timeout=self.CHANNEL_CLOSE_TIMEOUT,
)

@task
async def _on_close(self, exc: Optional[ExceptionType] = None) -> None:
if self.connection.is_opened and not self.__close_event.is_set():
await self.rpc(
spec.Channel.Close(
reply_code=self.__close_reply_code,
class_id=self.__close_class_id,
method_id=self.__close_method_id,
),
timeout=self.connection.connection_tune.heartbeat or None,
)
self.connection.channels.pop(self.number, None)
self.__close_event.set()
if not self.connection.is_opened or self.__close_event.is_set():
return

await self.rpc(
spec.Channel.Close(
reply_code=self.__close_reply_code,
class_id=self.__close_class_id,
method_id=self.__close_method_id,
),
timeout=self.connection.connection_tune.heartbeat or None,
)

await self.__close_event.wait()

async def basic_get(
self, queue: str = "", no_ack: bool = False,
Expand Down
61 changes: 42 additions & 19 deletions aiormq/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
from .base import Base, task
from .channel import Channel
from .exceptions import (
AMQPError, AuthenticationError, ConnectionChannelError, ConnectionClosed,
ConnectionCommandInvalid, ConnectionFrameError, ConnectionInternalError,
ConnectionNotAllowed, ConnectionNotImplemented, ConnectionResourceError,
ConnectionSyntaxError, ConnectionUnexpectedFrame, IncompatibleProtocolError,
ProbableAuthenticationError,
AMQPConnectionError, AMQPError, AuthenticationError, ConnectionChannelError,
ConnectionClosed, ConnectionCommandInvalid, ConnectionFrameError,
ConnectionInternalError, ConnectionNotAllowed, ConnectionNotImplemented,
ConnectionResourceError, ConnectionSyntaxError, ConnectionUnexpectedFrame,
IncompatibleProtocolError, ProbableAuthenticationError,
)
from .tools import Countdown, censor_url

Expand Down Expand Up @@ -162,7 +162,7 @@ async def get_frame(self) -> ReceivedFrame:
raise AMQPFrameError(fp.getvalue())

if self.reader is None:
raise ConnectionError
raise AMQPConnectionError()

fp.write(await self.reader.readexactly(6))

Expand All @@ -179,8 +179,26 @@ async def get_frame(self) -> ReceivedFrame:

fp.write(await self.reader.readexactly(frame_length + 1))
except asyncio.IncompleteReadError as e:
raise AMQPError(
"Server connection unexpectedly closed",
raise AMQPConnectionError(
"Server connection unexpectedly closed. "
f"Read {len(e.partial)} bytes but {e.expected} "
"bytes expected",
) from e
except ConnectionRefusedError as e:
raise AMQPConnectionError(
f"Server connection refused: {e!r}",
) from e
except ConnectionResetError as e:
raise AMQPConnectionError(
f"Server connection reset: {e!r}",
) from e
except ConnectionError as e:
raise AMQPConnectionError(
f"Server connection error: {e!r}",
) from e
except OSError as e:
raise AMQPConnectionError(
f"Server communication error: {e!r}",
) from e

return pamqp.frame.unmarshal(fp.getvalue())
Expand Down Expand Up @@ -372,7 +390,9 @@ async def _rpc(
frame_receiver: FrameReceiver,
wait_response: bool = True,
) -> Optional[FrameTypes]:

writer.write(pamqp.frame.marshal(request, 0))
await writer.drain()

if not wait_response:
return None
Expand Down Expand Up @@ -412,7 +432,7 @@ async def connect(

frame_receiver = FrameReceiver(reader)
except OSError as e:
raise ConnectionError(*e.args) from e
raise AMQPConnectionError(*e.args) from e

frame: Optional[FrameTypes]

Expand Down Expand Up @@ -469,7 +489,7 @@ async def connect(

if not isinstance(frame, spec.Connection.OpenOk):
raise AMQPInternalError("Connection.OpenOk", frame)
except Exception as e:
except BaseException as e:
await self.__close_writer(writer)
await self.close(e)
raise
Expand Down Expand Up @@ -691,11 +711,13 @@ async def __writer(self, writer: asyncio.StreamWriter) -> None:

log.debug("Sending %r to %r", frame, self)

await asyncio.gather(
writer.drain(),
self.__close_writer(writer),
return_exceptions=True,
)
try:
await asyncio.wait_for(
writer.drain(), timeout=self.__heartbeat_grace_timeout,
)
finally:
await self.__close_writer(writer)

raise
finally:
log.debug("Writer exited for %r", self)
Expand All @@ -707,10 +729,11 @@ async def __close_writer(self, writer: asyncio.StreamWriter) -> None:
else:
async def __close_writer(self, writer: asyncio.StreamWriter) -> None:
log.debug("Writer on connection %s closed", self)
if writer.can_write_eof():
writer.write_eof()
writer.close()
await writer.wait_closed()
with suppress(OSError, RuntimeError):
if writer.can_write_eof():
writer.write_eof()
writer.close()
await writer.wait_closed()

@staticmethod
def __check_writer(writer: asyncio.StreamWriter) -> bool:
Expand Down
17 changes: 14 additions & 3 deletions aiormq/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,26 @@ class AMQPError(Exception):
reason = "An unspecified AMQP error has occurred: %s"

def __repr__(self) -> str:
return "<%s: %s>" % (self.__class__.__name__, self.reason % self.args)
try:
return "<%s: %s>" % (
self.__class__.__name__, self.reason % self.args,
)
except TypeError:
# FIXME: if you are here file an issue
return f"<{self.__class__.__name__}: {self.args!r}>"


# Backward compatibility
AMQPException = AMQPError


class AMQPConnectionError(AMQPError):
reason = "Connection can not be opened"
class AMQPConnectionError(AMQPError, ConnectionError):
reason = "Unexpected connection problem"

def __repr__(self) -> str:
if self.args:
return f"<{self.__class__.__name__}. {'.'.join(self.args)}>"
return AMQPError.__repr__(self)


class IncompatibleProtocolError(AMQPConnectionError):
Expand Down
82 changes: 82 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import itertools
import os
import ssl
import uuid
Expand Down Expand Up @@ -417,3 +418,84 @@ async def test_connection_stuck(proxy, amqp_url: URL):

with pytest.raises(asyncio.CancelledError):
assert reader_task.result()


class BadNetwork:
def __init__(self, proxy, stair: int, disconnect_time: float):
self.proxy = proxy
self.stair = stair
self.disconnect_time = disconnect_time
self.num_bytes = 0
self.loop = asyncio.get_event_loop()
self.lock = asyncio.Lock()

proxy.set_content_processors(
self.client_to_server,
self.server_to_client,
)

async def disconnect(self):
async with self.lock:
await asyncio.sleep(self.disconnect_time)
await self.proxy.disconnect_all()
self.stair *= 2
self.num_bytes = 0

async def server_to_client(self, chunk: bytes) -> bytes:
async with self.lock:
self.num_bytes += len(chunk)
if self.num_bytes < self.stair:
return chunk
self.loop.create_task(self.disconnect())
return chunk

@staticmethod
def client_to_server(chunk: bytes) -> bytes:
return chunk


DISCONNECT_OFFSETS = [2 << i for i in range(1, 10)]
STAIR_STEPS = list(
itertools.product([0.0, 0.005, 0.05, 0.1], DISCONNECT_OFFSETS)
)
STAIR_STEPS_IDS = [
f"[{i // len(DISCONNECT_OFFSETS)}] {t}-{s}"
for i, (t, s) in enumerate(STAIR_STEPS)
]


@aiomisc.timeout(30)
@pytest.mark.parametrize(
"disconnect_time,stair", STAIR_STEPS,
ids=STAIR_STEPS_IDS
)
async def test_connection_close_stairway(
disconnect_time: float, stair: int, proxy, amqp_url: URL
):
url = amqp_url.with_host(
proxy.proxy_host,
).with_port(
proxy.proxy_port,
).update_query(heartbeat="1")

BadNetwork(proxy, stair, disconnect_time)

async def run():
connection = await aiormq.connect(url)
queue = asyncio.Queue()
channel = await connection.channel()
declare_ok = await channel.queue_declare(auto_delete=True)
await channel.basic_consume(
declare_ok.queue, queue.put, no_ack=True
)

while True:
await channel.basic_publish(
b"test", routing_key=declare_ok.queue
)
message: DeliveredMessage = await queue.get()
assert message.body == b"test"

for _ in range(5):
with pytest.raises(aiormq.AMQPError):
await run()

0 comments on commit d82244c

Please sign in to comment.