diff --git a/docs/examples/check_connection_broken/consumer_handle_connections_issues.py b/docs/examples/check_connection_broken/consumer_handle_connections_issues.py index a0ae251..91960f2 100644 --- a/docs/examples/check_connection_broken/consumer_handle_connections_issues.py +++ b/docs/examples/check_connection_broken/consumer_handle_connections_issues.py @@ -4,6 +4,7 @@ from rstream import ( AMQPMessage, Consumer, + DisconnectionErrorInfo, MessageContext, amqp_decoder, ) @@ -11,8 +12,13 @@ STREAM = "my-test-stream" -async def on_connection_closed(reason: Exception) -> None: - print("connection has been closed for reason: " + str(reason)) +async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + print( + "connection has been closed from stream: " + + str(disconnection_info.streams) + + " for reason: " + + str(disconnection_info.reason) + ) async def consume(): diff --git a/docs/examples/check_connection_broken/producer_handle_connections_issues.py b/docs/examples/check_connection_broken/producer_handle_connections_issues.py index 2a29fc1..271c886 100644 --- a/docs/examples/check_connection_broken/producer_handle_connections_issues.py +++ b/docs/examples/check_connection_broken/producer_handle_connections_issues.py @@ -1,14 +1,23 @@ import asyncio import time -from rstream import AMQPMessage, Producer +from rstream import ( + AMQPMessage, + DisconnectionErrorInfo, + Producer, +) STREAM = "my-test-stream" MESSAGES = 1000000 -async def on_connection_closed(reason: Exception) -> None: - print("connection has been closed for reason: " + str(reason)) +async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + print( + "connection has been closed from stream: " + + str(disconnection_info.streams) + + " for reason: " + + str(disconnection_info.reason) + ) async def publish(): diff --git a/docs/examples/check_connection_broken/superstream_consumer_handle_connections_issues.py b/docs/examples/check_connection_broken/superstream_consumer_handle_connections_issues.py new file mode 100644 index 0000000..e236bec --- /dev/null +++ b/docs/examples/check_connection_broken/superstream_consumer_handle_connections_issues.py @@ -0,0 +1,53 @@ +import asyncio +import signal + +from rstream import ( + AMQPMessage, + ConsumerOffsetSpecification, + DisconnectionErrorInfo, + MessageContext, + OffsetType, + SuperStreamConsumer, + amqp_decoder, +) + +cont = 0 + + +async def on_message(msg: AMQPMessage, message_context: MessageContext): + stream = await message_context.consumer.stream(message_context.subscriber_name) + offset = message_context.offset + print("Received message: {} from stream: {} - message offset: {}".format(msg, stream, offset)) + + +async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + print( + "connection has been closed from stream: " + + str(disconnection_info.streams) + + " for reason: " + + str(disconnection_info.reason) + ) + + +async def consume(): + consumer = SuperStreamConsumer( + host="localhost", + port=5552, + vhost="/", + username="guest", + password="guest", + super_stream="test_super_stream", + connection_closed_handler=on_connection_closed, + ) + + loop = asyncio.get_event_loop() + loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(consumer.close())) + offset_specification = ConsumerOffsetSpecification(OffsetType.FIRST, None) + await consumer.start() + await consumer.subscribe( + callback=on_message, decoder=amqp_decoder, offset_specification=offset_specification + ) + await consumer.run() + + +asyncio.run(consume()) diff --git a/docs/examples/check_connection_broken/superstream_producer_handle_connections_issues.py b/docs/examples/check_connection_broken/superstream_producer_handle_connections_issues.py new file mode 100644 index 0000000..629e5b6 --- /dev/null +++ b/docs/examples/check_connection_broken/superstream_producer_handle_connections_issues.py @@ -0,0 +1,54 @@ +import asyncio +import time + +from rstream import ( + AMQPMessage, + DisconnectionErrorInfo, + RouteType, + SuperStreamProducer, +) + +SUPER_STREAM = "test_super_stream" +MESSAGES = 10000000 + + +async def publish(): + # this value will be hashed using mumh3 hashing algorithm to decide the partition resolution for the message + async def routing_extractor(message: AMQPMessage) -> str: + return message.application_properties["id"] + + async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None: + + print( + "connection has been closed from stream: " + + str(disconnection_info.streams) + + " for reason: " + + str(disconnection_info.reason) + ) + + async with SuperStreamProducer( + "localhost", + username="guest", + password="guest", + routing_extractor=routing_extractor, + routing=RouteType.Hash, + connection_closed_handler=on_connection_closed, + super_stream=SUPER_STREAM, + ) as super_stream_producer: + + # sending a million of messages in AMQP format + start_time = time.perf_counter() + + for i in range(MESSAGES): + amqp_message = AMQPMessage( + body="hello: {}".format(i), + application_properties={"id": "{}".format(i)}, + ) + # send is asynchronous + await super_stream_producer.send(message=amqp_message) + + end_time = time.perf_counter() + print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds") + + +asyncio.run(publish()) diff --git a/rstream/__init__.py b/rstream/__init__.py index 4742d2d..c8d8cc7 100644 --- a/rstream/__init__.py +++ b/rstream/__init__.py @@ -3,6 +3,8 @@ from importlib import metadata +from .utils import DisconnectionErrorInfo + try: __version__ = metadata.version(__package__) __license__ = metadata.metadata(__package__)["license"] @@ -59,4 +61,5 @@ "StreamDoesNotExist", "OffsetSpecification", "EventContext", + "DisconnectionErrorInfo", ] diff --git a/rstream/client.py b/rstream/client.py index 4042f50..f44cc09 100644 --- a/rstream/client.py +++ b/rstream/client.py @@ -33,6 +33,7 @@ ) from .connection import Connection, ConnectionClosed from .schema import OffsetSpecification +from .utils import DisconnectionErrorInfo FT = TypeVar("FT", bound=schema.Frame) HT = Annotated[ @@ -66,7 +67,7 @@ def __init__( ssl_context: Optional[ssl.SSLContext] = None, frame_max: int, heartbeat: int, - connection_closed_handler: Optional[CB[Exception]] = None, + connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None, ): self.host = host self.port = port @@ -93,9 +94,13 @@ def __init__( self._last_heartbeat: float = 0 self._connection_closed_handler = connection_closed_handler + self._frames: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue) self._is_not_closed: bool = True + self._streams: list[str] = [] + self._conn_is_closed: bool = False + def start_task(self, name: str, coro: Awaitable[None]) -> None: assert name not in self._tasks task = self._tasks[name] = asyncio.create_task(coro) @@ -131,16 +136,21 @@ def remove_handler(self, frame_cls: Type[FT], name: Optional[str] = None) -> Non else: self._handlers[frame_cls].clear() + def get_connection(self) -> Optional[Connection]: + return self._conn + async def send_frame(self, frame: schema.Frame) -> None: logger.debug("Sending frame: %s", frame) assert self._conn try: await self._conn.write_frame(frame) except socket.error as e: + self._conn_is_closed = True if self._connection_closed_handler is None: print("TCP connection closed") else: - result = self._connection_closed_handler(e) + connection_error_info = DisconnectionErrorInfo(e, self._streams) + result = self._connection_closed_handler(connection_error_info) if result is not None and inspect.isawaitable(result): await result @@ -203,8 +213,11 @@ async def _listener(self) -> None: try: frame = await self._conn.read_frame() except ConnectionClosed as e: + self._conn_is_closed = True + if self._connection_closed_handler is not None: - result = self._connection_closed_handler(e) + connection_error_info = DisconnectionErrorInfo(e, self._streams) + result = self._connection_closed_handler(connection_error_info) if result is not None and inspect.isawaitable(result): await result else: @@ -212,7 +225,8 @@ async def _listener(self) -> None: break except socket.error as e: if self._connection_closed_handler is not None: - result = self._connection_closed_handler(e) + connection_error_info = DisconnectionErrorInfo(e, self._streams) + result = self._connection_closed_handler(connection_error_info) if result is not None and inspect.isawaitable(result): await result else: @@ -235,6 +249,7 @@ async def _listener(self) -> None: maybe_coro = handler(frame) if maybe_coro is not None: await maybe_coro + except Exception: logger.exception("Error while running handler %s of frame %s", handler, frame) @@ -270,7 +285,7 @@ def is_started(self) -> bool: async def close(self) -> None: logger.info("Stopping client %s:%s", self.host, self.port) - if self._conn is None: + if self._conn_is_closed is True: return if self.is_started: @@ -290,8 +305,9 @@ async def close(self) -> None: for subscriber_name in self._frames: await self.stop_task(f"run_delivery_handlers_{subscriber_name}") - await self._conn.close() - self._conn = None + if self._conn is not None: + await self._conn.close() + self._conn = None self.server_properties = None self._tasks.clear() @@ -417,6 +433,7 @@ async def query_leader_and_replicas( assert len(metadata_resp.metadata) == 1 metadata = metadata_resp.metadata[0] assert metadata.name == stream + self._streams.append(stream) brokers = {broker.reference: broker for broker in metadata_resp.brokers} leader = brokers[metadata.leader_ref] @@ -494,6 +511,8 @@ async def declare_publisher(self, stream: str, reference: str, publisher_id: int ) async def delete_publisher(self, publisher_id: int) -> None: + if self._conn is None: + return await self.sync_request( schema.DeletePublisher( self._corr_id_seq.next(), @@ -584,7 +603,9 @@ def __init__( self._clients: dict[Addr, Client] = {} async def get( - self, addr: Optional[Addr] = None, connection_closed_handler: Optional[CB[Exception]] = None + self, + addr: Optional[Addr] = None, + connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None, ) -> Client: """Get a client according to `addr` parameter @@ -610,7 +631,7 @@ async def get( return self._clients[desired_addr] async def _resolve_broker( - self, addr: Addr, connection_closed_handler: Optional[CB[Exception]] = None + self, addr: Addr, connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None ) -> Client: desired_host, desired_port = addr.host, str(addr.port) @@ -636,7 +657,9 @@ async def _resolve_broker( f"Failed to connect to {desired_host}:{desired_port} after {self.max_retries} tries" ) - async def new(self, addr: Addr, connection_closed_handler: Optional[CB[Exception]] = None) -> Client: + async def new( + self, addr: Addr, connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None + ) -> Client: host, port = addr client = Client( host=host, diff --git a/rstream/consumer.py b/rstream/consumer.py index 9a7293c..f9a6851 100644 --- a/rstream/consumer.py +++ b/rstream/consumer.py @@ -24,6 +24,7 @@ OffsetType, ) from .schema import OffsetSpecification +from .utils import DisconnectionErrorInfo MT = TypeVar("MT") CB = Annotated[Callable[[MT, Any], Union[None, Awaitable[None]]], "Message callback type"] @@ -71,7 +72,7 @@ def __init__( heartbeat: int = 60, load_balancer_mode: bool = False, max_retries: int = 20, - connection_closed_handler: Optional[CB_CONN[Exception]] = None, + connection_closed_handler: Optional[CB_CONN[DisconnectionErrorInfo]] = None, ): self._pool = ClientPool( host, @@ -331,7 +332,8 @@ async def stream_exists(self, stream: str) -> bool: return await self.default_client.stream_exists(stream) async def stream(self, subscriber_name) -> str: - + if subscriber_name not in self._subscribers: + return "" return self._subscribers[subscriber_name].stream def get_stream(self, subscriber_name) -> str: diff --git a/rstream/producer.py b/rstream/producer.py index 9e89cf4..2423af9 100644 --- a/rstream/producer.py +++ b/rstream/producer.py @@ -29,7 +29,7 @@ CompressionType, ICompressionCodec, ) -from .utils import RawMessage +from .utils import DisconnectionErrorInfo, RawMessage MessageT = TypeVar("MessageT", _MessageProtocol, bytes) MT = TypeVar("MT") @@ -75,7 +75,7 @@ def __init__( max_retries: int = 20, default_batch_publishing_delay: float = 0.2, default_context_switch_value: int = 1000, - connection_closed_handler: Optional[CB[Exception]] = None, + connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None, ): self._pool = ClientPool( host, @@ -125,7 +125,6 @@ async def start(self) -> None: async def close(self) -> None: # flush messages still in buffer if self.task is not None: - for stream in self._buffered_messages: await self._publish_buffered_messages(stream) self.task.cancel() @@ -393,6 +392,10 @@ async def _timer(self): async def _publish_buffered_messages(self, stream: str) -> None: + if stream in self._clients: + if self._clients[stream].get_connection() is None: + return + async with self._buffered_messages_lock: if len(self._buffered_messages[stream]): await self._send_batch(stream, self._buffered_messages[stream], sync=False) diff --git a/rstream/superstream_consumer.py b/rstream/superstream_consumer.py index 0766286..9ba8d00 100644 --- a/rstream/superstream_consumer.py +++ b/rstream/superstream_consumer.py @@ -25,6 +25,7 @@ ) from .consumer import Consumer, EventContext, MessageContext from .superstream import DefaultSuperstreamMetadata +from .utils import DisconnectionErrorInfo MT = TypeVar("MT") CB = Annotated[Callable[[MT], Union[None, Awaitable[None]]], "Message callback type"] @@ -45,7 +46,7 @@ def __init__( load_balancer_mode: bool = False, max_retries: int = 20, super_stream: str, - connection_closed_handler: Optional[CB[Exception]] = None, + connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None, ): self._pool = ClientPool( host, @@ -171,6 +172,7 @@ async def _create_consumer(self) -> Consumer: heartbeat=self.heartbeat, load_balancer_mode=False, max_retries=self.max_retries, + connection_closed_handler=self._connection_closed_handler, ) await consumer.start() diff --git a/rstream/superstream_producer.py b/rstream/superstream_producer.py index 5421f75..00acae5 100644 --- a/rstream/superstream_producer.py +++ b/rstream/superstream_producer.py @@ -21,6 +21,7 @@ RoutingKeyRoutingStrategy, RoutingStrategy, ) +from .utils import DisconnectionErrorInfo MT = TypeVar("MT") CB = Annotated[Callable[[MT], Awaitable[Any]], "Message callback type"] @@ -51,7 +52,7 @@ def __init__( load_balancer_mode: bool = False, max_retries: int = 20, default_batch_publishing_delay: float = 0.2, - connection_closed_handler: Optional[CB[Exception]] = None + connection_closed_handler: Optional[CB[DisconnectionErrorInfo]] = None ): self._pool = ClientPool( host, @@ -97,6 +98,7 @@ async def _get_producer(self) -> Producer: heartbeat=self.heartbeat, load_balancer_mode=self.load_balancer_mode, default_batch_publishing_delay=self.default_batch_publishing_delay, + connection_closed_handler=self._connection_closed_handler, ) await producer.start() self._producer = producer diff --git a/rstream/utils.py b/rstream/utils.py index 9267e19..7c9a8cc 100644 --- a/rstream/utils.py +++ b/rstream/utils.py @@ -39,3 +39,9 @@ async def _wait(self) -> Any: def __await__(self) -> Generator[Any, None, Any]: return self._wait().__await__() + + +@dataclass +class DisconnectionErrorInfo: + reason: Exception + streams: list[str]