From b6fb1bc529ebe6d083c9b43ed7ddd9a9b58469e5 Mon Sep 17 00:00:00 2001 From: Daniele Date: Mon, 30 Oct 2023 11:41:10 +0100 Subject: [PATCH] Testing for disconnection scenarios (#140) * coverage for disconnection scenarios * fixing a few flows - fixing tests * adding new tests, fixing some bugs detected * some further tests and updates * updating examples * adding test for send_batch when the connnection get terminated by the server * adding test for send_batch when the connnection get terminated by the server * Handle disconnection Signed-off-by: Gabriele Santomaggio * reformatting * avoid rising on_close callback for locator connections --------- Signed-off-by: Gabriele Santomaggio Co-authored-by: Gabriele Santomaggio --- .../consumer_handle_connections_issues.py | 47 ------ ...er_handle_connections_issues_with_close.py | 63 ++++---- ...er_handle_connections_issues_with_close.py | 20 ++- ...ream_consumer_handle_connections_issues.py | 8 +- ...ream_producer_handle_connections_issues.py | 44 +++--- rstream/client.py | 98 +++++------- rstream/consumer.py | 9 +- rstream/producer.py | 21 ++- rstream/superstream_consumer.py | 1 + rstream/superstream_producer.py | 4 +- tests/http_requests.py | 31 ++++ tests/test_consumer.py | 93 +++++++++++ tests/test_producer.py | 146 ++++++++++++++++++ tests/util.py | 44 +++++- 14 files changed, 454 insertions(+), 175 deletions(-) delete mode 100644 docs/examples/check_connection_broken/consumer_handle_connections_issues.py diff --git a/docs/examples/check_connection_broken/consumer_handle_connections_issues.py b/docs/examples/check_connection_broken/consumer_handle_connections_issues.py deleted file mode 100644 index 66a27e5..0000000 --- a/docs/examples/check_connection_broken/consumer_handle_connections_issues.py +++ /dev/null @@ -1,47 +0,0 @@ -import asyncio -import signal - -from rstream import ( - AMQPMessage, - Consumer, - DisconnectionErrorInfo, - MessageContext, - amqp_decoder, -) - -STREAM = "my-test-stream" - - -async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: - print( - "connection has been closed from stream: " - + str(disconnection_info.streams) - + " for reason: " - + disconnection_info.reason - ) - - -async def consume(): - consumer = Consumer( - host="localhost", - port=5552, - vhost="/", - username="guest", - password="guest", - connection_closed_handler=on_connection_closed, - ) - - loop = asyncio.get_event_loop() - loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(consumer.close())) - - async def on_message(msg: AMQPMessage, message_context: MessageContext): - stream = message_context.consumer.get_stream(message_context.subscriber_name) - offset = message_context.offset - print("Got message: {} from stream {}, offset {}".format(msg, stream, offset)) - - await consumer.start() - await consumer.subscribe(stream=STREAM, callback=on_message, decoder=amqp_decoder) - await consumer.run() - - -asyncio.run(consume()) diff --git a/docs/examples/check_connection_broken/consumer_handle_connections_issues_with_close.py b/docs/examples/check_connection_broken/consumer_handle_connections_issues_with_close.py index 7718da3..c7c2817 100644 --- a/docs/examples/check_connection_broken/consumer_handle_connections_issues_with_close.py +++ b/docs/examples/check_connection_broken/consumer_handle_connections_issues_with_close.py @@ -1,56 +1,59 @@ import asyncio -import time +import signal from rstream import ( AMQPMessage, + Consumer, DisconnectionErrorInfo, - Producer, + MessageContext, + amqp_decoder, ) STREAM = "my-test-stream" -MESSAGES = 10000000 +COUNT = 0 connection_is_closed = False -async def publish(): +async def consume(): async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: print( "connection has been closed from stream: " + str(disconnection_info.streams) + " for reason: " - + disconnection_info.reason + + str(disconnection_info.reason) ) global connection_is_closed - connection_is_closed = True - await producer.close() - - # avoid to use async context in this case as we are closing the producer ourself in the callback - # in this case we avoid double closing - producer = Producer( - "localhost", username="guest", password="guest", connection_closed_handler=on_connection_closed + # avoid multiple simultaneous disconnection to call close multiple times + if connection_is_closed is False: + await consumer.close() + connection_is_closed = True + + consumer = Consumer( + host="localhost", + port=5552, + vhost="/", + username="guest", + password="guest", + connection_closed_handler=on_connection_closed, ) - await producer.start() - # create a stream if it doesn't already exist - await producer.create_stream(STREAM, exists_ok=True) - - # sending a million of messages in AMQP format - start_time = time.perf_counter() + loop = asyncio.get_event_loop() + loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(consumer.close())) - for i in range(MESSAGES): - amqp_message = AMQPMessage( - body="hello: {}".format(i), - ) - # send is asynchronous - if connection_is_closed is False: - await producer.send(stream=STREAM, message=amqp_message) - else: - break + async def on_message(msg: AMQPMessage, message_context: MessageContext): + stream = message_context.consumer.get_stream(message_context.subscriber_name) + offset = message_context.offset + global COUNT + COUNT = COUNT + 1 + if COUNT % 1000000 == 0: + # print("Got message: {} from stream {}, offset {}".format(msg, stream, offset)) + print("consumed 1 million messages") - end_time = time.perf_counter() - print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds") + await consumer.start() + await consumer.subscribe(stream=STREAM, callback=on_message, decoder=amqp_decoder) + await consumer.run() -asyncio.run(publish()) +asyncio.run(consume()) diff --git a/docs/examples/check_connection_broken/producer_handle_connections_issues_with_close.py b/docs/examples/check_connection_broken/producer_handle_connections_issues_with_close.py index 828a998..84406b4 100644 --- a/docs/examples/check_connection_broken/producer_handle_connections_issues_with_close.py +++ b/docs/examples/check_connection_broken/producer_handle_connections_issues_with_close.py @@ -8,7 +8,7 @@ ) STREAM = "my-test-stream" -MESSAGES = 1000000 +MESSAGES = 10000000 connection_is_closed = False @@ -18,36 +18,42 @@ async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> No "connection has been closed from stream: " + str(disconnection_info.streams) + " for reason: " - + disconnection_info.reason + + str(disconnection_info.reason) ) - # clean close or reconnect - await producer.close() global connection_is_closed connection_is_closed = True + print("creating Producer") + # producer will be closed at the end by the async context manager + # both if connection is still alive or not async with Producer( "localhost", username="guest", password="guest", connection_closed_handler=on_connection_closed ) as producer: + # create a stream if it doesn't already exist await producer.create_stream(STREAM, exists_ok=True) # sending a million of messages in AMQP format start_time = time.perf_counter() - global connection_is_closed + print("Sending MESSAGES") for i in range(MESSAGES): amqp_message = AMQPMessage( body="hello: {}".format(i), ) # send is asynchronous + global connection_is_closed if connection_is_closed is False: await producer.send(stream=STREAM, message=amqp_message) else: break - end_time = time.perf_counter() - print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds") + if i % 10000 == 0: + print("sent 10000 messages") + + end_time = time.perf_counter() + print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds") asyncio.run(publish()) diff --git a/docs/examples/check_connection_broken/superstream_consumer_handle_connections_issues.py b/docs/examples/check_connection_broken/superstream_consumer_handle_connections_issues.py index 6fab178..e42bf16 100644 --- a/docs/examples/check_connection_broken/superstream_consumer_handle_connections_issues.py +++ b/docs/examples/check_connection_broken/superstream_consumer_handle_connections_issues.py @@ -12,6 +12,7 @@ ) count = 0 +connection_is_closed = False async def on_message(msg: AMQPMessage, message_context: MessageContext): @@ -31,7 +32,12 @@ async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> No + " for reason: " + disconnection_info.reason ) - await consumer.close() + + global connection_is_closed + if connection_is_closed is False: + connection_is_closed = True + # avoid multiple simultaneous disconnection to call close multiple times + await consumer.close() consumer = SuperStreamConsumer( host="localhost", diff --git a/docs/examples/check_connection_broken/superstream_producer_handle_connections_issues.py b/docs/examples/check_connection_broken/superstream_producer_handle_connections_issues.py index 4b39a09..101ed3a 100644 --- a/docs/examples/check_connection_broken/superstream_producer_handle_connections_issues.py +++ b/docs/examples/check_connection_broken/superstream_producer_handle_connections_issues.py @@ -30,10 +30,10 @@ async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> No global connection_is_closed connection_is_closed = True - await super_stream_producer.close() - - # avoiding using async context as we close the producer ourself in on_connection_closed callback - super_stream_producer = SuperStreamProducer( + # super_stream_producer will be closed by the async context manager + # both if connection is still alive or not + print("creating super_stream producer") + async with SuperStreamProducer( "localhost", username="guest", password="guest", @@ -41,25 +41,27 @@ async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> No routing=RouteType.Hash, connection_closed_handler=on_connection_closed, super_stream=SUPER_STREAM, - ) - - await super_stream_producer.start() - - # sending a million of messages in AMQP format - start_time = time.perf_counter() - global connection_is_closed + ) as super_stream_producer: - for i in range(MESSAGES): - amqp_message = AMQPMessage( - body="hello: {}".format(i), - application_properties={"id": "{}".format(i)}, - ) + # sending a million of messages in AMQP format + start_time = time.perf_counter() + global connection_is_closed - # send is asynchronous - if connection_is_closed is False: - await super_stream_producer.send(message=amqp_message) - else: - break + print("sending messages") + for i in range(MESSAGES): + amqp_message = AMQPMessage( + body="hello: {}".format(i), + application_properties={"id": "{}".format(i)}, + ) + + # send is asynchronous + if connection_is_closed is False: + await super_stream_producer.send(message=amqp_message) + else: + break + + if i % 10000 == 0: + print("sent 10000 MESSAGES") end_time = time.perf_counter() print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds") diff --git a/rstream/client.py b/rstream/client.py index 1f8ac3c..7002d83 100644 --- a/rstream/client.py +++ b/rstream/client.py @@ -134,11 +134,13 @@ def add_handler( def remove_handler(self, frame_cls: Type[FT], name: Optional[str] = None) -> None: if name is not None: - if frame_cls in self._handlers: - del self._handlers[frame_cls][name] + if name in self._handlers[frame_cls]: + if frame_cls in self._handlers: + del self._handlers[frame_cls][name] else: - if frame_cls in self._handlers: - self._handlers[frame_cls].clear() + if name in self._handlers[frame_cls]: + if frame_cls in self._handlers: + self._handlers[frame_cls].clear() def is_connection_alive(self) -> bool: return self._is_not_closed @@ -147,16 +149,11 @@ async def send_frame(self, frame: schema.Frame) -> None: logger.debug("Sending frame: %s", frame) assert self._conn try: - await self._conn.write_frame(frame) + if self.is_connection_alive(): + await self._conn.write_frame(frame) except socket.error: self._is_not_closed = False - if self._connection_closed_handler is None: - logger.exception("TCP connection closed") - else: - connection_error_info = DisconnectionErrorInfo("Socket Error", self._streams) - result = self._connection_closed_handler(connection_error_info) - if result is not None and inspect.isawaitable(result): - await result + logger.debug("TCP connection closed") def wait_frame( self, @@ -214,53 +211,41 @@ async def _run_delivery_handlers(self, subscriber_name: str, handler: HT[FT]): async def _listener(self) -> None: assert self._conn - while True: - try: - if self.is_connection_alive(): - frame = await self._conn.read_frame() - except ConnectionClosed: - - if self._connection_closed_handler is not None and self.is_connection_alive(): + try: + while self.is_connection_alive(): + frame = await self._conn.read_frame() + + if not self.is_connection_alive(): + break + + logger.debug("Received frame: %s", frame) + + _key = frame.key, frame.corr_id + fut = self._waiters.get(_key) + if fut is not None: + fut.set_result(frame) + del self._waiters[_key] + + for subscriber_name, handler in list(self._handlers.get(frame.__class__, {}).items()): + try: + if frame.__class__ == schema.Deliver: + await self._frames[subscriber_name].put(frame) + else: + maybe_coro = handler(frame) + if maybe_coro is not None: + await maybe_coro + + except BaseException: + logger.debug("Error while running handler %s of frame %s", handler, frame) + except (ConnectionClosed, socket.error): + if self._connection_closed_handler is not None: + self._is_not_closed = False + # don't raise for locator connections without streams + if len(self._streams) > 0: connection_error_info = DisconnectionErrorInfo("Connection Closed", self._streams) result = self._connection_closed_handler(connection_error_info) if result is not None and inspect.isawaitable(result): await result - else: - logger.exception("TCP connection closed") - - self._is_not_closed = False - break - except socket.error: - if self._conn is not None: - if self._connection_closed_handler is not None and self.is_connection_alive(): - connection_error_info = DisconnectionErrorInfo("Socket Error", self._streams) - result = self._connection_closed_handler(connection_error_info) - if result is not None and inspect.isawaitable(result): - await result - else: - logger.debug("TCP connection closed") - self._is_not_closed = False - break - - logger.debug("Received frame: %s", frame) - - _key = frame.key, frame.corr_id - fut = self._waiters.get(_key) - if fut is not None: - fut.set_result(frame) - del self._waiters[_key] - - for subscriber_name, handler in list(self._handlers.get(frame.__class__, {}).items()): - try: - if frame.__class__ == schema.Deliver: - await self._frames[subscriber_name].put(frame) - else: - maybe_coro = handler(frame) - if maybe_coro is not None: - await maybe_coro - - except Exception: - logger.exception("Error while running handler %s of frame %s", handler, frame) def _start_heartbeat(self) -> None: self.start_task("heartbeat_sender", self._heartbeat_sender()) @@ -320,6 +305,7 @@ async def close(self) -> None: logger.exception("exception in client close() sync_request", exc) self._is_not_closed = False + await asyncio.sleep(0.2) await self.stop_task("listener") for subscriber_name in self._frames: @@ -714,7 +700,7 @@ async def new( return client async def close(self) -> None: - for client in self._clients.values(): + for client in list(self._clients.values()): await client.close() self._clients.clear() diff --git a/rstream/consumer.py b/rstream/consumer.py index acdf369..c2ef64b 100644 --- a/rstream/consumer.py +++ b/rstream/consumer.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import random import ssl from dataclasses import dataclass @@ -29,6 +30,7 @@ MT = TypeVar("MT") CB = Annotated[Callable[[MT, Any], Union[None, Awaitable[None]]], "Message callback type"] CB_CONN = Annotated[Callable[[MT], Union[None, Awaitable[None]]], "Message callback type"] +logger = logging.getLogger(__name__) @dataclass @@ -243,7 +245,12 @@ async def unsubscribe(self, subscriber_name: str) -> None: schema.Deliver, name=subscriber.reference, ) - await subscriber.client.unsubscribe(subscriber.subscription_id) + try: + await asyncio.wait_for(subscriber.client.unsubscribe(subscriber.subscription_id), 5) + except asyncio.TimeoutError: + logger.debug("timeout when closing consumer and deleting publisher") + except BaseException as exc: + logger.debug("exception in delete_publisher in Producer.close:", exc) del self._subscribers[subscriber_name] async def query_offset(self, stream: str, subscriber_name: str) -> int: diff --git a/rstream/producer.py b/rstream/producer.py index d8cfee6..a8a7561 100644 --- a/rstream/producer.py +++ b/rstream/producer.py @@ -135,17 +135,30 @@ async def start(self) -> None: ) async def close(self) -> None: - self._close_called = True + + # check if we are in a server disconnection situation: + # in this case we need avoid other send + # otherwise if is a normal close() we need to send the last item in batch + for publisher in self._publishers.values(): + if publisher.client.is_connection_alive() is False: + self._close_called = True + # just in this special case give time to all the tasks to complete + await asyncio.sleep(0.2) + break + # flush messages still in buffer if self._default_client is None: return + if self.default_client.is_connection_alive(): if self.task is not None: for stream in self._buffered_messages: await self._publish_buffered_messages(stream) self.task.cancel() - for publisher in self._publishers.values(): + self._close_called = True + + for publisher in list(self._publishers.values()): if publisher.client.is_connection_alive(): try: await asyncio.wait_for(publisher.client.delete_publisher(publisher.id), 5) @@ -229,6 +242,9 @@ async def send_batch( on_publish_confirm: Optional[CB[ConfirmationStatus]] = None, ) -> list[int]: + if self._close_called: + return [] + wrapped_batch = [] for item in batch: wrapped_item = _MessageNotification( @@ -308,7 +324,6 @@ async def _send_batch( publishing_ids_callback[item.callback].add(publishing_id) if len(messages) > 0: - await publisher.client.send_frame( schema.Publish( publisher_id=publisher.id, diff --git a/rstream/superstream_consumer.py b/rstream/superstream_consumer.py index eba8b18..a8bc21c 100644 --- a/rstream/superstream_consumer.py +++ b/rstream/superstream_consumer.py @@ -181,6 +181,7 @@ async def _create_consumer(self) -> Consumer: load_balancer_mode=False, max_retries=self.max_retries, connection_closed_handler=self._connection_closed_handler, + connection_name=self._connection_name, ) await consumer.start() diff --git a/rstream/superstream_producer.py b/rstream/superstream_producer.py index 427873b..fd2fe6e 100644 --- a/rstream/superstream_producer.py +++ b/rstream/superstream_producer.py @@ -145,7 +145,9 @@ async def start(self) -> None: self._routing_strategy = RoutingKeyRoutingStrategy(self.routing_extractor) async def close(self) -> None: + if self._default_client is not None: + await self._default_client.close() + self._default_client = None await self._pool.close() if self._producer is not None: await self._producer.close() - self._default_client = None diff --git a/tests/http_requests.py b/tests/http_requests.py index a8e0274..b778168 100644 --- a/tests/http_requests.py +++ b/tests/http_requests.py @@ -1,4 +1,7 @@ +import urllib.parse + import requests +from requests.auth import HTTPBasicAuth def create_exchange(exchange_name: str) -> int: @@ -22,3 +25,31 @@ def create_binding(exchange_name: str, routing_key: str, stream_name: str): response = requests.post(request, json=data) return response.status_code + + +def get_connections() -> list: + request = "http://localhost:15672/api/connections" + response = requests.get(request, auth=HTTPBasicAuth("guest", "guest")) + response.raise_for_status() + return response.json() + + +def get_connection(name: str) -> bool: + request = "http://guest:guest@localhost:15672/api/connections/" + urllib.parse.quote(name) + response = requests.get(request, auth=HTTPBasicAuth("guest", "guest")) + if response.status_code == 404: + return False + return True + + +def get_connection_present(connection_name: str, connections: list) -> bool: + for connection in connections: + if connection["client_properties"]["connection_name"] == connection_name: + return True + return False + + +def delete_connection(name: str) -> int: + request = "http://guest:guest@localhost:15672/api/connections/" + urllib.parse.quote(name) + response = requests.delete(request, auth=HTTPBasicAuth("guest", "guest")) + return response.status_code diff --git a/tests/test_consumer.py b/tests/test_consumer.py index f86024d..f1b22b0 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT import asyncio +import logging from functools import partial import pytest @@ -11,11 +12,13 @@ AMQPMessage, Consumer, ConsumerOffsetSpecification, + DisconnectionErrorInfo, MessageContext, OffsetType, Producer, SuperStreamConsumer, SuperStreamProducer, + amqp_decoder, exceptions, ) @@ -25,10 +28,12 @@ consumer_update_handler_offset, on_message, run_consumer, + task_to_delete_connection, wait_for, ) pytestmark = pytest.mark.asyncio +logger = logging.getLogger(__name__) async def test_create_stream_already_exists(stream: str, consumer: Consumer) -> None: @@ -516,3 +521,91 @@ async def on_message_first(msg: AMQPMessage, message_context: MessageContext): await producer.send_batch(stream, messages) await wait_for(lambda: len(captured) >= 1) + + +async def test_consumer_connection_broke(stream: str) -> None: + + connection_broke = False + stream_disconnected = None + consumer_broke: Consumer + + async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + nonlocal connection_broke + connection_broke = True + nonlocal consumer_broke + nonlocal stream_disconnected + stream_disconnected = disconnection_info.streams.pop() + + await consumer_broke.close() + + consumer_broke = Consumer( + host="localhost", + port=5552, + vhost="/", + username="guest", + password="guest", + connection_closed_handler=on_connection_closed, + connection_name="test-connection", + ) + + async def on_message(msg: AMQPMessage, message_context: MessageContext): + message_context.consumer.get_stream(message_context.subscriber_name) + + asyncio.create_task(task_to_delete_connection("test-connection")) + + await consumer_broke.start() + await consumer_broke.subscribe(stream=stream, callback=on_message, decoder=amqp_decoder) + await consumer_broke.run() + + assert connection_broke is True + assert stream_disconnected == stream + + await asyncio.sleep(1) + + +async def test_super_stream_consumer_connection_broke(super_stream: str) -> None: + + connection_broke = False + streams_disconnected: set[str] = set() + consumer_broke: Consumer + + async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + nonlocal connection_broke + nonlocal streams_disconnected + # avoiding multiple connection closed to hit + if connection_broke is True: + for stream in disconnection_info.streams: + streams_disconnected.add(stream) + return None + + connection_broke = True + + for stream in disconnection_info.streams: + streams_disconnected.add(stream) + + await super_stream_consumer_broke.close() + + super_stream_consumer_broke = SuperStreamConsumer( + host="localhost", + port=5552, + vhost="/", + username="guest", + password="guest", + connection_closed_handler=on_connection_closed, + connection_name="test-connection", + super_stream=super_stream, + ) + + async def on_message(msg: AMQPMessage, message_context: MessageContext): + message_context.consumer.get_stream(message_context.subscriber_name) + + asyncio.create_task(task_to_delete_connection("test-connection")) + + await super_stream_consumer_broke.start() + await super_stream_consumer_broke.subscribe(callback=on_message, decoder=amqp_decoder) + await super_stream_consumer_broke.run() + + assert connection_broke is True + assert "test-super-stream-0" in streams_disconnected + assert "test-super-stream-1" in streams_disconnected + assert "test-super-stream-2" in streams_disconnected diff --git a/tests/test_producer.py b/tests/test_producer.py index 410acf1..043deb9 100644 --- a/tests/test_producer.py +++ b/tests/test_producer.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT import asyncio +import logging from functools import partial import pytest @@ -10,8 +11,10 @@ AMQPMessage, CompressionType, Consumer, + DisconnectionErrorInfo, Producer, RawMessage, + RouteType, SuperStreamConsumer, SuperStreamProducer, amqp_decoder, @@ -21,9 +24,12 @@ from .util import ( on_publish_confirm_client_callback, on_publish_confirm_client_callback2, + routing_extractor_generic, + task_to_delete_connection, wait_for, ) +logger = logging.getLogger(__name__) pytestmark = pytest.mark.asyncio @@ -59,6 +65,18 @@ async def test_publishing_sequence(stream: str, producer: Producer, consumer: Co assert captured == [b"one", b"two", b"three"] +async def test_publishing_several_messages(stream: str, producer: Producer, consumer: Consumer) -> None: + captured: list[bytes] = [] + await consumer.subscribe( + stream, callback=lambda message, message_context: captured.append(bytes(message)) + ) + + for i in range(0, 100000): + await producer.send(stream, b"one") + + await wait_for(lambda: len(captured) == 100000) + + async def test_publishing_sequence_subbatching_nocompression( stream: str, producer: Producer, consumer: Consumer ) -> None: @@ -431,3 +449,131 @@ async def publish_with_ids(*ids): await publish_with_ids(1, 2, 3) await wait_for(lambda: len(confirmed_messages) == 3) + + +async def test_producer_connection_broke(stream: str) -> None: + + connection_broke = False + stream_disconnected = None + producer_broke: Producer + + async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + nonlocal connection_broke + connection_broke = True + nonlocal producer_broke + + nonlocal stream_disconnected + stream_disconnected = disconnection_info.streams.pop() + + await producer_broke.close() + + producer_broke = Producer( + "localhost", + username="guest", + password="guest", + connection_closed_handler=on_connection_closed, + connection_name="test-connection", + ) + + await producer_broke.start() + asyncio.create_task(task_to_delete_connection("test-connection")) + + while connection_broke is False: + await producer_broke.send(stream, b"one") + await asyncio.sleep(0) + + await producer_broke.close() + + assert connection_broke is True + assert stream_disconnected == stream + + +async def test_producer_connection_broke_with_send_batch(stream: str) -> None: + + connection_broke = False + stream_disconnected = None + producer_broke: Producer + + async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + + nonlocal connection_broke + if connection_broke is True: + return None + + connection_broke = True + nonlocal producer_broke + + nonlocal stream_disconnected + stream_disconnected = disconnection_info.streams.pop() + + producer_broke = Producer( + "localhost", + username="guest", + password="guest", + connection_closed_handler=on_connection_closed, + connection_name="test-connection", + ) + + await producer_broke.start() + asyncio.create_task(task_to_delete_connection("test-connection")) + + while connection_broke is False: + batch = [] + for i in range(10000): + amqp_message = AMQPMessage( + body="hello: {}".format(i), + ) + batch.append(amqp_message) + await producer_broke.send_batch(stream, batch) # type: ignore + batch.clear() + + await producer_broke.close() + assert connection_broke is True + assert stream_disconnected == stream + + +async def test_super_stream_producer_connection_broke(super_stream: str) -> None: + + connection_broke = False + streams_disconnected: set[str] = set() + producer_broke: Producer + + async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + nonlocal connection_broke + connection_broke = True + nonlocal producer_broke + + nonlocal streams_disconnected + for stream in disconnection_info.streams: + streams_disconnected.add(stream) + + super_stream_producer_broke = SuperStreamProducer( + "localhost", + username="guest", + password="guest", + routing_extractor=routing_extractor_generic, + routing=RouteType.Hash, + connection_closed_handler=on_connection_closed, + connection_name="test-connection", + super_stream=super_stream, + ) + + await super_stream_producer_broke.start() + + asyncio.create_task(task_to_delete_connection("test-connection")) + i = 0 + while connection_broke is False: + amqp_message = AMQPMessage( + body="hello: {}".format(i), + application_properties={"id": "{}".format(i)}, + ) + i = i + 1 + # send is asynchronous + await super_stream_producer_broke.send(message=amqp_message) + + await super_stream_producer_broke.close() + + assert connection_broke is True + assert "test-super-stream-0" in streams_disconnected + assert "test-super-stream-1" in streams_disconnected + assert "test-super-stream-2" in streams_disconnected diff --git a/tests/util.py b/tests/util.py index 454fbfb..0a35c1a 100644 --- a/tests/util.py +++ b/tests/util.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT import asyncio +import logging from collections import defaultdict from typing import Any, Awaitable, Callable, Optional @@ -16,7 +17,23 @@ amqp_decoder, ) +from .http_requests import ( + delete_connection, + get_connection, + get_connection_present, + get_connections, +) + captured: list[bytes] = [] +logger = logging.getLogger(__name__) + + +async def wait_for(condition, timeout=1): + async def _wait(): + while not condition(): + await asyncio.sleep(0.01) + + await asyncio.wait_for(_wait(), timeout) async def consumer_update_handler_next(is_active: bool, event_context: EventContext) -> OffsetSpecification: @@ -34,14 +51,6 @@ async def consumer_update_handler_offset(is_active: bool, event_context: EventCo return OffsetSpecification(OffsetType.OFFSET, 10) -async def wait_for(condition, timeout=1): - async def _wait(): - while not condition(): - await asyncio.sleep(0.01) - - await asyncio.wait_for(_wait(), timeout) - - async def on_publish_confirm_client_callback( confirmation: ConfirmationStatus, confirmed_messages: list[int], errored_messages: list[int] ) -> None: @@ -66,6 +75,10 @@ async def routing_extractor(message: AMQPMessage) -> str: return "0" +async def routing_extractor_generic(message: AMQPMessage) -> str: + return message.application_properties["id"] + + async def routing_extractor_for_sac(message: AMQPMessage) -> str: return str(message.properties.message_id) @@ -109,3 +122,18 @@ async def run_consumer( properties=properties, consumer_update_listener=consumer_update_listener, ) + + +async def task_to_delete_connection(connection_name: str) -> None: + + # delay a few seconds before deleting the connection + await asyncio.sleep(5) + + connections = get_connections() + + await wait_for(lambda: get_connection_present(connection_name, connections) is True) + + for connection in connections: + if connection["client_properties"]["connection_name"] == connection_name: + delete_connection(connection["name"]) + await wait_for(lambda: get_connection(connection["name"]) is False)