From 77e517ae778017e84d3ba71fa71b40abd11e2ad0 Mon Sep 17 00:00:00 2001 From: Daniele Palaia Date: Wed, 11 Oct 2023 15:52:41 +0200 Subject: [PATCH] new modifications after rebase --- ...er_handle_connections_issues_with_close.py | 50 ++++++++++++++ ...er_handle_connections_issues_with_close.py | 52 ++++++++++++++ rstream/client.py | 67 ++++++++++--------- rstream/consumer.py | 7 +- rstream/producer.py | 16 +++-- 5 files changed, 152 insertions(+), 40 deletions(-) create mode 100644 docs/examples/check_connection_broken/consumer_handle_connections_issues_with_close.py create mode 100644 docs/examples/check_connection_broken/producer_handle_connections_issues_with_close.py diff --git a/docs/examples/check_connection_broken/consumer_handle_connections_issues_with_close.py b/docs/examples/check_connection_broken/consumer_handle_connections_issues_with_close.py new file mode 100644 index 0000000..501cced --- /dev/null +++ b/docs/examples/check_connection_broken/consumer_handle_connections_issues_with_close.py @@ -0,0 +1,50 @@ +import asyncio +import signal + +from rstream import ( + AMQPMessage, + Consumer, + DisconnectionErrorInfo, + MessageContext, + amqp_decoder, +) + +STREAM = "my-test-stream" + + +async def consume(): + async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + print( + "connection has been closed from stream: " + + str(disconnection_info.streams) + + " for reason: " + + str(disconnection_info.reason) + ) + + # clean close or reconnect + await consumer.close() + + consumer = Consumer( + host="localhost", + port=5552, + vhost="/", + username="guest", + password="guest", + connection_closed_handler=on_connection_closed, + ) + + loop = asyncio.get_event_loop() + loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(consumer.close())) + + async def on_message(msg: AMQPMessage, message_context: MessageContext): + stream = message_context.consumer.get_stream(message_context.subscriber_name) + offset = message_context.offset + # print("Got message: {} from stream {}, offset {}".format(msg, stream, offset)) + + await consumer.start() + await consumer.subscribe(stream=STREAM, callback=on_message, decoder=amqp_decoder) + print("im here") + await consumer.run() + + +asyncio.run(consume()) diff --git a/docs/examples/check_connection_broken/producer_handle_connections_issues_with_close.py b/docs/examples/check_connection_broken/producer_handle_connections_issues_with_close.py new file mode 100644 index 0000000..e2e90a5 --- /dev/null +++ b/docs/examples/check_connection_broken/producer_handle_connections_issues_with_close.py @@ -0,0 +1,52 @@ +import asyncio +import time + +from rstream import ( + AMQPMessage, + DisconnectionErrorInfo, + Producer, +) + +STREAM = "my-test-stream" +MESSAGES = 1000000 +connection_is_closed = False + + +async def publish(): + async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + print( + "connection has been closed from stream: " + + str(disconnection_info.streams) + + " for reason: " + + str(disconnection_info.reason) + ) + + # clean close or reconnect + await producer.close() + global connection_is_closed + connection_is_closed = True + + async with Producer( + "localhost", username="guest", password="guest", connection_closed_handler=on_connection_closed + ) as producer: + # create a stream if it doesn't already exist + await producer.create_stream(STREAM, exists_ok=True) + + # sending a million of messages in AMQP format + start_time = time.perf_counter() + + for i in range(MESSAGES): + amqp_message = AMQPMessage( + body="hello: {}".format(i), + ) + # send is asynchronous + if connection_is_closed is False: + await producer.send(stream=STREAM, message=amqp_message) + else: + break + + end_time = time.perf_counter() + print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds") + + +asyncio.run(publish()) diff --git a/rstream/client.py b/rstream/client.py index f44cc09..e92385c 100644 --- a/rstream/client.py +++ b/rstream/client.py @@ -99,7 +99,6 @@ def __init__( self._is_not_closed: bool = True self._streams: list[str] = [] - self._conn_is_closed: bool = False def start_task(self, name: str, coro: Awaitable[None]) -> None: assert name not in self._tasks @@ -136,8 +135,10 @@ def remove_handler(self, frame_cls: Type[FT], name: Optional[str] = None) -> Non else: self._handlers[frame_cls].clear() - def get_connection(self) -> Optional[Connection]: - return self._conn + def get_is_connection_active(self) -> bool: + if self._conn is None: + return False + return True async def send_frame(self, frame: schema.Frame) -> None: logger.debug("Sending frame: %s", frame) @@ -145,9 +146,9 @@ async def send_frame(self, frame: schema.Frame) -> None: try: await self._conn.write_frame(frame) except socket.error as e: - self._conn_is_closed = True + self._conn = None if self._connection_closed_handler is None: - print("TCP connection closed") + logger.exception("TCP connection closed") else: connection_error_info = DisconnectionErrorInfo(e, self._streams) result = self._connection_closed_handler(connection_error_info) @@ -199,9 +200,10 @@ async def _run_delivery_handlers(self, subscriber_name: str, handler: HT[FT]): while self._is_not_closed: frame_entry = await self._frames[subscriber_name].get() try: - maybe_coro = handler(frame_entry) - if maybe_coro is not None: - await maybe_coro + if self._conn is not None: + maybe_coro = handler(frame_entry) + if maybe_coro is not None: + await maybe_coro except Exception as e: logger.exception( "Error while handling %s frame exception raised %s", frame_entry.__class__, e @@ -211,26 +213,30 @@ async def _listener(self) -> None: assert self._conn while True: try: - frame = await self._conn.read_frame() + if self._conn is not None: + frame = await self._conn.read_frame() except ConnectionClosed as e: - self._conn_is_closed = True - if self._connection_closed_handler is not None: + if self._connection_closed_handler is not None and self._conn is not None: connection_error_info = DisconnectionErrorInfo(e, self._streams) result = self._connection_closed_handler(connection_error_info) if result is not None and inspect.isawaitable(result): await result else: - print("TCP connection closed") + logger.exception("TCP connection closed") + + self._conn = None break except socket.error as e: - if self._connection_closed_handler is not None: - connection_error_info = DisconnectionErrorInfo(e, self._streams) - result = self._connection_closed_handler(connection_error_info) - if result is not None and inspect.isawaitable(result): - await result - else: - print("TCP connection closed") + if self._conn is not None: + if self._connection_closed_handler is not None and self._conn is not None: + connection_error_info = DisconnectionErrorInfo(e, self._streams) + result = self._connection_closed_handler(connection_error_info) + if result is not None and inspect.isawaitable(result): + await result + else: + print("TCP connection closed") + self._conn = None break logger.debug("Received frame: %s", frame) @@ -241,7 +247,7 @@ async def _listener(self) -> None: fut.set_result(frame) del self._waiters[_key] - for subscriber_name, handler in self._handlers.get(frame.__class__, {}).items(): + for subscriber_name, handler in list(self._handlers.get(frame.__class__, {}).items()): try: if frame.__class__ == schema.Deliver: await self._frames[subscriber_name].put(frame) @@ -285,18 +291,17 @@ def is_started(self) -> bool: async def close(self) -> None: logger.info("Stopping client %s:%s", self.host, self.port) - if self._conn_is_closed is True: - return + if self._conn is not None: - if self.is_started: - await self.sync_request( - schema.Close( - self._corr_id_seq.next(), - code=1, - reason="OK", - ), - resp_schema=schema.CloseResponse, - ) + if self.is_started: + await self.sync_request( + schema.Close( + self._corr_id_seq.next(), + code=1, + reason="OK", + ), + resp_schema=schema.CloseResponse, + ) await self.stop_task("listener") diff --git a/rstream/consumer.py b/rstream/consumer.py index f9a6851..327c7df 100644 --- a/rstream/consumer.py +++ b/rstream/consumer.py @@ -117,8 +117,8 @@ async def close(self) -> None: self.stop() for subscriber in list(self._subscribers.values()): - await self.unsubscribe(subscriber.reference) - # await self.store_offset(subscriber.stream, subscriber.reference, subscriber.offset) + if subscriber.client.get_is_connection_active() is True: + await self.unsubscribe(subscriber.reference) self._subscribers.clear() @@ -337,5 +337,6 @@ async def stream(self, subscriber_name) -> str: return self._subscribers[subscriber_name].stream def get_stream(self, subscriber_name) -> str: - + if subscriber_name not in self._subscribers: + return "" return self._subscribers[subscriber_name].stream diff --git a/rstream/producer.py b/rstream/producer.py index 2423af9..ced388b 100644 --- a/rstream/producer.py +++ b/rstream/producer.py @@ -124,13 +124,17 @@ async def start(self) -> None: async def close(self) -> None: # flush messages still in buffer - if self.task is not None: - for stream in self._buffered_messages: - await self._publish_buffered_messages(stream) - self.task.cancel() + if self._default_client is None: + return + if self.default_client.get_is_connection_active() is True: + if self.task is not None: + for stream in self._buffered_messages: + await self._publish_buffered_messages(stream) + self.task.cancel() for publisher in self._publishers.values(): - await publisher.client.delete_publisher(publisher.id) + if publisher.client.get_is_connection_active() is True: + await publisher.client.delete_publisher(publisher.id) publisher.client.remove_handler(schema.PublishConfirm, publisher.reference) publisher.client.remove_handler(schema.PublishError, publisher.reference) @@ -393,7 +397,7 @@ async def _timer(self): async def _publish_buffered_messages(self, stream: str) -> None: if stream in self._clients: - if self._clients[stream].get_connection() is None: + if self._clients[stream].get_is_connection_active() is False: return async with self._buffered_messages_lock: