From fbc807f2b11afeb921eeb22eb99ea74ad8100fa8 Mon Sep 17 00:00:00 2001 From: Daniele Palaia Date: Tue, 17 Oct 2023 10:16:29 +0200 Subject: [PATCH] implementing code review suggestions --- rstream/client.py | 24 +++++++++++------------- rstream/consumer.py | 2 +- rstream/producer.py | 30 ++++++++++++++++++++++++------ 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/rstream/client.py b/rstream/client.py index e92385c..56233c7 100644 --- a/rstream/client.py +++ b/rstream/client.py @@ -135,10 +135,8 @@ def remove_handler(self, frame_cls: Type[FT], name: Optional[str] = None) -> Non else: self._handlers[frame_cls].clear() - def get_is_connection_active(self) -> bool: - if self._conn is None: - return False - return True + def is_connection_alive(self) -> bool: + return self._is_not_closed async def send_frame(self, frame: schema.Frame) -> None: logger.debug("Sending frame: %s", frame) @@ -146,7 +144,7 @@ async def send_frame(self, frame: schema.Frame) -> None: try: await self._conn.write_frame(frame) except socket.error as e: - self._conn = None + self._is_not_closed = False if self._connection_closed_handler is None: logger.exception("TCP connection closed") else: @@ -197,10 +195,10 @@ async def run_queue_listener_task(self, subscriber_name: str, handler: HT[FT]) - async def _run_delivery_handlers(self, subscriber_name: str, handler: HT[FT]): - while self._is_not_closed: + while self.is_connection_alive(): frame_entry = await self._frames[subscriber_name].get() try: - if self._conn is not None: + if self.is_connection_alive(): maybe_coro = handler(frame_entry) if maybe_coro is not None: await maybe_coro @@ -213,11 +211,11 @@ async def _listener(self) -> None: assert self._conn while True: try: - if self._conn is not None: + if self.is_connection_alive(): frame = await self._conn.read_frame() except ConnectionClosed as e: - if self._connection_closed_handler is not None and self._conn is not None: + if self._connection_closed_handler is not None and self.is_connection_alive(): connection_error_info = DisconnectionErrorInfo(e, self._streams) result = self._connection_closed_handler(connection_error_info) if result is not None and inspect.isawaitable(result): @@ -225,18 +223,18 @@ async def _listener(self) -> None: else: logger.exception("TCP connection closed") - self._conn = None + self._is_not_closed = False break except socket.error as e: if self._conn is not None: - if self._connection_closed_handler is not None and self._conn is not None: + if self._connection_closed_handler is not None and self.is_connection_alive(): connection_error_info = DisconnectionErrorInfo(e, self._streams) result = self._connection_closed_handler(connection_error_info) if result is not None and inspect.isawaitable(result): await result else: print("TCP connection closed") - self._conn = None + self._is_not_closed = False break logger.debug("Received frame: %s", frame) @@ -291,7 +289,7 @@ def is_started(self) -> bool: async def close(self) -> None: logger.info("Stopping client %s:%s", self.host, self.port) - if self._conn is not None: + if self._conn is not None and self.is_connection_alive(): if self.is_started: await self.sync_request( diff --git a/rstream/consumer.py b/rstream/consumer.py index 327c7df..11c8ef4 100644 --- a/rstream/consumer.py +++ b/rstream/consumer.py @@ -117,7 +117,7 @@ async def close(self) -> None: self.stop() for subscriber in list(self._subscribers.values()): - if subscriber.client.get_is_connection_active() is True: + if subscriber.client.is_connection_alive(): await self.unsubscribe(subscriber.reference) self._subscribers.clear() diff --git a/rstream/producer.py b/rstream/producer.py index ced388b..1904cac 100644 --- a/rstream/producer.py +++ b/rstream/producer.py @@ -5,6 +5,7 @@ import asyncio import inspect +import logging import ssl from collections import defaultdict from dataclasses import dataclass @@ -59,6 +60,9 @@ class ConfirmationStatus: response_code: int = 0 +logger = logging.getLogger(__name__) + + class Producer: def __init__( self, @@ -105,6 +109,7 @@ def __init__( self._default_context_switch_counter = 0 self._default_context_switch_value = default_context_switch_value self._connection_closed_handler = connection_closed_handler + self._close_called = False @property def default_client(self) -> Client: @@ -120,21 +125,28 @@ async def __aexit__(self, *_: Any) -> None: await self.close() async def start(self) -> None: + self._close_called = False self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler) async def close(self) -> None: + self._close_called = True # flush messages still in buffer if self._default_client is None: return - if self.default_client.get_is_connection_active() is True: + if self.default_client.is_connection_alive(): if self.task is not None: for stream in self._buffered_messages: await self._publish_buffered_messages(stream) self.task.cancel() for publisher in self._publishers.values(): - if publisher.client.get_is_connection_active() is True: - await publisher.client.delete_publisher(publisher.id) + if publisher.client.is_connection_alive(): + try: + await publisher.client.delete_publisher(publisher.id) + except asyncio.TimeoutError: + logger.debug("timeout when closing producer and deleting publisher") + except BaseException as exc: + logger.debug("delete_publisher in Producer.close:", exc) publisher.client.remove_handler(schema.PublishConfirm, publisher.reference) publisher.client.remove_handler(schema.PublishError, publisher.reference) @@ -227,6 +239,9 @@ async def _send_batch( if len(batch) == 0: raise ValueError("Empty batch") + if self._close_called: + return [] + messages = [] publishing_ids = set() publishing_ids_callback: dict[CB[ConfirmationStatus], set[int]] = defaultdict(set) @@ -389,15 +404,18 @@ async def send_sub_entry( # After the timeout send the messages in _buffered_messages in batches async def _timer(self): - while True: + while not self._close_called: await asyncio.sleep(self._default_batch_publishing_delay) for stream in self._buffered_messages: - await self._publish_buffered_messages(stream) + try: + await self._publish_buffered_messages(stream) + except BaseException as exc: + logger.debug("producer _timer exception: ", {exc}) async def _publish_buffered_messages(self, stream: str) -> None: if stream in self._clients: - if self._clients[stream].get_is_connection_active() is False: + if self._clients[stream].is_connection_alive() is False: return async with self._buffered_messages_lock: