Skip to content

Commit

Permalink
new modifications after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniele Palaia authored and Daniele Palaia committed Oct 12, 2023
1 parent 78f2b22 commit 77e517a
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import asyncio
import signal

from rstream import (
AMQPMessage,
Consumer,
DisconnectionErrorInfo,
MessageContext,
amqp_decoder,
)

STREAM = "my-test-stream"


async def consume():
async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None:
print(
"connection has been closed from stream: "
+ str(disconnection_info.streams)
+ " for reason: "
+ str(disconnection_info.reason)
)

# clean close or reconnect
await consumer.close()

consumer = Consumer(
host="localhost",
port=5552,
vhost="/",
username="guest",
password="guest",
connection_closed_handler=on_connection_closed,
)

loop = asyncio.get_event_loop()
loop.add_signal_handler(signal.SIGINT, lambda: asyncio.create_task(consumer.close()))

async def on_message(msg: AMQPMessage, message_context: MessageContext):
stream = message_context.consumer.get_stream(message_context.subscriber_name)
offset = message_context.offset
# print("Got message: {} from stream {}, offset {}".format(msg, stream, offset))

await consumer.start()
await consumer.subscribe(stream=STREAM, callback=on_message, decoder=amqp_decoder)
print("im here")
await consumer.run()


asyncio.run(consume())
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio
import time

from rstream import (
AMQPMessage,
DisconnectionErrorInfo,
Producer,
)

STREAM = "my-test-stream"
MESSAGES = 1000000
connection_is_closed = False


async def publish():
async def on_connection_closed(disconnection_info: DisconnectionErrorInfo) -> None:
print(
"connection has been closed from stream: "
+ str(disconnection_info.streams)
+ " for reason: "
+ str(disconnection_info.reason)
)

# clean close or reconnect
await producer.close()
global connection_is_closed
connection_is_closed = True

async with Producer(
"localhost", username="guest", password="guest", connection_closed_handler=on_connection_closed
) as producer:
# create a stream if it doesn't already exist
await producer.create_stream(STREAM, exists_ok=True)

# sending a million of messages in AMQP format
start_time = time.perf_counter()

for i in range(MESSAGES):
amqp_message = AMQPMessage(
body="hello: {}".format(i),
)
# send is asynchronous
if connection_is_closed is False:
await producer.send(stream=STREAM, message=amqp_message)
else:
break

end_time = time.perf_counter()
print(f"Sent {MESSAGES} messages in {end_time - start_time:0.4f} seconds")


asyncio.run(publish())
67 changes: 36 additions & 31 deletions rstream/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __init__(
self._is_not_closed: bool = True

self._streams: list[str] = []
self._conn_is_closed: bool = False

def start_task(self, name: str, coro: Awaitable[None]) -> None:
assert name not in self._tasks
Expand Down Expand Up @@ -136,18 +135,20 @@ def remove_handler(self, frame_cls: Type[FT], name: Optional[str] = None) -> Non
else:
self._handlers[frame_cls].clear()

def get_connection(self) -> Optional[Connection]:
return self._conn
def get_is_connection_active(self) -> bool:
if self._conn is None:
return False
return True

async def send_frame(self, frame: schema.Frame) -> None:
logger.debug("Sending frame: %s", frame)
assert self._conn
try:
await self._conn.write_frame(frame)
except socket.error as e:
self._conn_is_closed = True
self._conn = None
if self._connection_closed_handler is None:
print("TCP connection closed")
logger.exception("TCP connection closed")
else:
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
Expand Down Expand Up @@ -199,9 +200,10 @@ async def _run_delivery_handlers(self, subscriber_name: str, handler: HT[FT]):
while self._is_not_closed:
frame_entry = await self._frames[subscriber_name].get()
try:
maybe_coro = handler(frame_entry)
if maybe_coro is not None:
await maybe_coro
if self._conn is not None:
maybe_coro = handler(frame_entry)
if maybe_coro is not None:
await maybe_coro
except Exception as e:
logger.exception(
"Error while handling %s frame exception raised %s", frame_entry.__class__, e
Expand All @@ -211,26 +213,30 @@ async def _listener(self) -> None:
assert self._conn
while True:
try:
frame = await self._conn.read_frame()
if self._conn is not None:
frame = await self._conn.read_frame()
except ConnectionClosed as e:
self._conn_is_closed = True

if self._connection_closed_handler is not None:
if self._connection_closed_handler is not None and self._conn is not None:
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
if result is not None and inspect.isawaitable(result):
await result
else:
print("TCP connection closed")
logger.exception("TCP connection closed")

self._conn = None
break
except socket.error as e:
if self._connection_closed_handler is not None:
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
if result is not None and inspect.isawaitable(result):
await result
else:
print("TCP connection closed")
if self._conn is not None:
if self._connection_closed_handler is not None and self._conn is not None:
connection_error_info = DisconnectionErrorInfo(e, self._streams)
result = self._connection_closed_handler(connection_error_info)
if result is not None and inspect.isawaitable(result):
await result
else:
print("TCP connection closed")
self._conn = None
break

logger.debug("Received frame: %s", frame)
Expand All @@ -241,7 +247,7 @@ async def _listener(self) -> None:
fut.set_result(frame)
del self._waiters[_key]

for subscriber_name, handler in self._handlers.get(frame.__class__, {}).items():
for subscriber_name, handler in list(self._handlers.get(frame.__class__, {}).items()):
try:
if frame.__class__ == schema.Deliver:
await self._frames[subscriber_name].put(frame)
Expand Down Expand Up @@ -285,18 +291,17 @@ def is_started(self) -> bool:
async def close(self) -> None:
logger.info("Stopping client %s:%s", self.host, self.port)

if self._conn_is_closed is True:
return
if self._conn is not None:

if self.is_started:
await self.sync_request(
schema.Close(
self._corr_id_seq.next(),
code=1,
reason="OK",
),
resp_schema=schema.CloseResponse,
)
if self.is_started:
await self.sync_request(
schema.Close(
self._corr_id_seq.next(),
code=1,
reason="OK",
),
resp_schema=schema.CloseResponse,
)

await self.stop_task("listener")

Expand Down
7 changes: 4 additions & 3 deletions rstream/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ async def close(self) -> None:
self.stop()

for subscriber in list(self._subscribers.values()):
await self.unsubscribe(subscriber.reference)
# await self.store_offset(subscriber.stream, subscriber.reference, subscriber.offset)
if subscriber.client.get_is_connection_active() is True:
await self.unsubscribe(subscriber.reference)

self._subscribers.clear()

Expand Down Expand Up @@ -337,5 +337,6 @@ async def stream(self, subscriber_name) -> str:
return self._subscribers[subscriber_name].stream

def get_stream(self, subscriber_name) -> str:

if subscriber_name not in self._subscribers:
return ""
return self._subscribers[subscriber_name].stream
16 changes: 10 additions & 6 deletions rstream/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,17 @@ async def start(self) -> None:

async def close(self) -> None:
# flush messages still in buffer
if self.task is not None:
for stream in self._buffered_messages:
await self._publish_buffered_messages(stream)
self.task.cancel()
if self._default_client is None:
return
if self.default_client.get_is_connection_active() is True:
if self.task is not None:
for stream in self._buffered_messages:
await self._publish_buffered_messages(stream)
self.task.cancel()

for publisher in self._publishers.values():
await publisher.client.delete_publisher(publisher.id)
if publisher.client.get_is_connection_active() is True:
await publisher.client.delete_publisher(publisher.id)
publisher.client.remove_handler(schema.PublishConfirm, publisher.reference)
publisher.client.remove_handler(schema.PublishError, publisher.reference)

Expand Down Expand Up @@ -393,7 +397,7 @@ async def _timer(self):
async def _publish_buffered_messages(self, stream: str) -> None:

if stream in self._clients:
if self._clients[stream].get_connection() is None:
if self._clients[stream].get_is_connection_active() is False:
return

async with self._buffered_messages_lock:
Expand Down

0 comments on commit 77e517a

Please sign in to comment.