Skip to content

Commit

Permalink
implementing code review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniele Palaia authored and Daniele Palaia committed Oct 18, 2023
1 parent 77e517a commit fbc807f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 20 deletions.
24 changes: 11 additions & 13 deletions rstream/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,16 @@ def remove_handler(self, frame_cls: Type[FT], name: Optional[str] = None) -> Non
else:
self._handlers[frame_cls].clear()

def get_is_connection_active(self) -> bool:
if self._conn is None:
return False
return True
def is_connection_alive(self) -> bool:
return self._is_not_closed

async def send_frame(self, frame: schema.Frame) -> None:
logger.debug("Sending frame: %s", frame)
assert self._conn
try:
await self._conn.write_frame(frame)
except socket.error as e:
self._conn = None
self._is_not_closed = False
if self._connection_closed_handler is None:
logger.exception("TCP connection closed")
else:
Expand Down Expand Up @@ -197,10 +195,10 @@ async def run_queue_listener_task(self, subscriber_name: str, handler: HT[FT]) -

async def _run_delivery_handlers(self, subscriber_name: str, handler: HT[FT]):

while self._is_not_closed:
while self.is_connection_alive():
frame_entry = await self._frames[subscriber_name].get()
try:
if self._conn is not None:
if self.is_connection_alive():
maybe_coro = handler(frame_entry)
if maybe_coro is not None:
await maybe_coro
Expand All @@ -213,30 +211,30 @@ async def _listener(self) -> None:
assert self._conn
while True:
try:
if self._conn is not None:
if self.is_connection_alive():
frame = await self._conn.read_frame()
except ConnectionClosed as e:

if self._connection_closed_handler is not None and self._conn is not None:
if self._connection_closed_handler is not None and self.is_connection_alive():
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
if result is not None and inspect.isawaitable(result):
await result
else:
logger.exception("TCP connection closed")

self._conn = None
self._is_not_closed = False
break
except socket.error as e:
if self._conn is not None:
if self._connection_closed_handler is not None and self._conn is not None:
if self._connection_closed_handler is not None and self.is_connection_alive():
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
if result is not None and inspect.isawaitable(result):
await result
else:
print("TCP connection closed")
self._conn = None
self._is_not_closed = False
break

logger.debug("Received frame: %s", frame)
Expand Down Expand Up @@ -291,7 +289,7 @@ def is_started(self) -> bool:
async def close(self) -> None:
logger.info("Stopping client %s:%s", self.host, self.port)

if self._conn is not None:
if self._conn is not None and self.is_connection_alive():

if self.is_started:
await self.sync_request(
Expand Down
2 changes: 1 addition & 1 deletion rstream/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def close(self) -> None:
self.stop()

for subscriber in list(self._subscribers.values()):
if subscriber.client.get_is_connection_active() is True:
if subscriber.client.is_connection_alive():
await self.unsubscribe(subscriber.reference)

self._subscribers.clear()
Expand Down
30 changes: 24 additions & 6 deletions rstream/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import asyncio
import inspect
import logging
import ssl
from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -59,6 +60,9 @@ class ConfirmationStatus:
response_code: int = 0


logger = logging.getLogger(__name__)


class Producer:
def __init__(
self,
Expand Down Expand Up @@ -105,6 +109,7 @@ def __init__(
self._default_context_switch_counter = 0
self._default_context_switch_value = default_context_switch_value
self._connection_closed_handler = connection_closed_handler
self._close_called = False

@property
def default_client(self) -> Client:
Expand All @@ -120,21 +125,28 @@ async def __aexit__(self, *_: Any) -> None:
await self.close()

async def start(self) -> None:
self._close_called = False
self._default_client = await self._pool.get(connection_closed_handler=self._connection_closed_handler)

async def close(self) -> None:
self._close_called = True
# flush messages still in buffer
if self._default_client is None:
return
if self.default_client.get_is_connection_active() is True:
if self.default_client.is_connection_alive():
if self.task is not None:
for stream in self._buffered_messages:
await self._publish_buffered_messages(stream)
self.task.cancel()

for publisher in self._publishers.values():
if publisher.client.get_is_connection_active() is True:
await publisher.client.delete_publisher(publisher.id)
if publisher.client.is_connection_alive():
try:
await publisher.client.delete_publisher(publisher.id)
except asyncio.TimeoutError:
logger.debug("timeout when closing producer and deleting publisher")
except BaseException as exc:
logger.debug("delete_publisher in Producer.close:", exc)
publisher.client.remove_handler(schema.PublishConfirm, publisher.reference)
publisher.client.remove_handler(schema.PublishError, publisher.reference)

Expand Down Expand Up @@ -227,6 +239,9 @@ async def _send_batch(
if len(batch) == 0:
raise ValueError("Empty batch")

if self._close_called:
return []

messages = []
publishing_ids = set()
publishing_ids_callback: dict[CB[ConfirmationStatus], set[int]] = defaultdict(set)
Expand Down Expand Up @@ -389,15 +404,18 @@ async def send_sub_entry(
# After the timeout send the messages in _buffered_messages in batches
async def _timer(self):

while True:
while not self._close_called:
await asyncio.sleep(self._default_batch_publishing_delay)
for stream in self._buffered_messages:
await self._publish_buffered_messages(stream)
try:
await self._publish_buffered_messages(stream)
except BaseException as exc:
logger.debug("producer _timer exception: ", {exc})

async def _publish_buffered_messages(self, stream: str) -> None:

if stream in self._clients:
if self._clients[stream].get_is_connection_active() is False:
if self._clients[stream].is_connection_alive() is False:
return

async with self._buffered_messages_lock:
Expand Down

0 comments on commit fbc807f

Please sign in to comment.