Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add internal queues for pubsub message transmission to clients. #63

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions fastapi_websocket_pubsub/event_notifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import copy
from collections import OrderedDict
from time import monotonic_ns
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union

from fastapi_websocket_rpc import RpcChannel
Expand Down Expand Up @@ -37,7 +39,73 @@ class Subscription(BaseModel):

# Publish event callback signature
def EventCallback(subscription: Subscription, data: Any):
pass
...


SUBSCRIPTION_TASK_CLEANUP_TIME = 60


class SubscriptionPusher:
def __init__(self):
self._queues: Dict[str, asyncio.Queue] = {}
self._tasks: Dict[str, asyncio.Task] = {}
self._queue_flush_times: OrderedDict[str, int] = OrderedDict()
self._queues_lock = asyncio.Lock()
self._cleanup_task = asyncio.create_task(self._cleanup_queues())

async def trigger(self, subscription: Subscription, data):
need_create_task = False
await self._queues_lock.acquire()
try:
if not subscription.id in self._queues:
self._queues[subscription.id] = asyncio.Queue()
need_create_task = True
finally:
self._queues_lock.release()
if need_create_task:
self._queue_flush_times[subscription.id] = monotonic_ns()
self._tasks[subscription.id] = asyncio.create_task(
self._handle_queue(subscription)
)
await self._queues[subscription.id].put(data)

async def _handle_queue(self, subscription: Subscription):
while True:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should catch CancelledError and quit the while

try:
data = await self._queues[subscription.id].get()
try:
await subscription.callback(subscription, data)
except Exception:
logger.exception(
"Unable to handle subscription %r data %r:", subscription, data
)
if self._queues[subscription.id].empty():
self._queue_flush_times[subscription.id] = monotonic_ns()
self._queue_flush_times.move_to_end(subscription.id)
except Exception:
logger.exception("Handle failed:")

async def _cleanup_queues(self):
while True:
most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME
await self._queues_lock.acquire()
try:
# This code is safe because there are no any awaits except for the lock, if this changes then we need to more locks
subscriptions_to_delete = []
for subscription_id, last_flush in self._queue_flush_times.items():
if last_flush > most_recent_flush_to_delete:
# We're done
break
if self._queues[subscription_id].empty():
subscriptions_to_delete.append(subscription_id)
for subscription_id in subscriptions_to_delete:
self._tasks[subscription_id].cancel()
del self._tasks[subscription_id]
del self._queues[subscription_id]
del self._queue_flush_times[subscription_id]
finally:
self._queues_lock.release()
await asyncio.sleep(1)


class EventNotifier:
Expand Down Expand Up @@ -70,6 +138,7 @@ def __init__(self):
self._on_unsubscribe_events = []
# List of restriction checks to perform on every action on the channel
self._channel_restrictions = []
self._subscription_pusher = SubscriptionPusher()

def gen_subscriber_id(self):
return gen_uid()
Expand Down Expand Up @@ -175,7 +244,8 @@ async def trigger_callback(
subscriber_id: SubscriberId,
subscription: Subscription,
):
await subscription.callback(subscription, data)
logger.info("trigger on %s %s", subscription.id, topic)
await self._subscription_pusher.trigger(subscription, data)

async def callback_subscribers(
self,
Expand Down
Loading