diff --git a/fastapi_websocket_pubsub/event_notifier.py b/fastapi_websocket_pubsub/event_notifier.py index 35cad9a..8c5a8a4 100644 --- a/fastapi_websocket_pubsub/event_notifier.py +++ b/fastapi_websocket_pubsub/event_notifier.py @@ -1,5 +1,7 @@ import asyncio import copy +from collections import OrderedDict +from time import monotonic_ns from typing import Any, Callable, Coroutine, Dict, List, Optional, Union from fastapi_websocket_rpc import RpcChannel @@ -37,7 +39,73 @@ class Subscription(BaseModel): # Publish event callback signature def EventCallback(subscription: Subscription, data: Any): - pass + ... + + +SUBSCRIPTION_TASK_CLEANUP_TIME = 60 + + +class SubscriptionPusher: + def __init__(self): + self._queues: Dict[str, asyncio.Queue] = {} + self._tasks: Dict[str, asyncio.Task] = {} + self._queue_flush_times: OrderedDict[str, int] = OrderedDict() + self._queues_lock = asyncio.Lock() + self._cleanup_task = asyncio.create_task(self._cleanup_queues()) + + async def trigger(self, subscription: Subscription, data): + need_create_task = False + await self._queues_lock.acquire() + try: + if not subscription.id in self._queues: + self._queues[subscription.id] = asyncio.Queue() + need_create_task = True + finally: + self._queues_lock.release() + if need_create_task: + self._queue_flush_times[subscription.id] = monotonic_ns() + self._tasks[subscription.id] = asyncio.create_task( + self._handle_queue(subscription) + ) + await self._queues[subscription.id].put(data) + + async def _handle_queue(self, subscription: Subscription): + while True: + try: + data = await self._queues[subscription.id].get() + try: + await subscription.callback(subscription, data) + except Exception: + logger.exception( + "Unable to handle subscription %r data %r:", subscription, data + ) + if self._queues[subscription.id].empty(): + self._queue_flush_times[subscription.id] = monotonic_ns() + self._queue_flush_times.move_to_end(subscription.id) + except Exception: + logger.exception("Handle failed:") + + async def _cleanup_queues(self): + while True: + most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME + await self._queues_lock.acquire() + try: + # This code is safe because there are no any awaits except for the lock, if this changes then we need to more locks + subscriptions_to_delete = [] + for subscription_id, last_flush in self._queue_flush_times.items(): + if last_flush > most_recent_flush_to_delete: + # We're done + break + if self._queues[subscription_id].empty(): + subscriptions_to_delete.append(subscription_id) + for subscription_id in subscriptions_to_delete: + self._tasks[subscription_id].cancel() + del self._tasks[subscription_id] + del self._queues[subscription_id] + del self._queue_flush_times[subscription_id] + finally: + self._queues_lock.release() + await asyncio.sleep(1) class EventNotifier: @@ -70,6 +138,7 @@ def __init__(self): self._on_unsubscribe_events = [] # List of restriction checks to perform on every action on the channel self._channel_restrictions = [] + self._subscription_pusher = SubscriptionPusher() def gen_subscriber_id(self): return gen_uid() @@ -175,7 +244,8 @@ async def trigger_callback( subscriber_id: SubscriberId, subscription: Subscription, ): - await subscription.callback(subscription, data) + logger.info("trigger on %s %s", subscription.id, topic) + await self._subscription_pusher.trigger(subscription, data) async def callback_subscribers( self,