Skip to content

Commit

Permalink
Add internal queues for pubsub message transmission to clients.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaul Kremer committed Aug 23, 2023
1 parent 8c08e11 commit 43f84cb
Showing 1 changed file with 72 additions and 2 deletions.
74 changes: 72 additions & 2 deletions fastapi_websocket_pubsub/event_notifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import copy
from collections import OrderedDict
from time import monotonic_ns
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union

from fastapi_websocket_rpc import RpcChannel
Expand Down Expand Up @@ -37,7 +39,74 @@ class Subscription(BaseModel):

# Publish event callback signature
def EventCallback(subscription: Subscription, data: Any):
pass
...


SUBSCRIPTION_TASK_CLEANUP_TIME = 60


class SubscriptionPusher:
def __init__(self):
self._queues: Dict[str, asyncio.Queue] = {}
self._tasks: Dict[str, asyncio.Task] = {}
self._queue_flush_times: OrderedDict[str, int] = OrderedDict()
self._queues_lock = asyncio.Lock()
self._cleanup_task = asyncio.create_task(self._cleanup_queues)

async def trigger(self, subscription: Subscription, data):
need_create_task = False
await self._queues_lock.acquire()
try:
if not subscription.id in self._queues:
self._queues[subscription.id] = asyncio.Queue()
need_create_task = True
finally:
self._queues_lock.release()
if need_create_task:
self._queue_flush_times[subscription.id] = monotonic_ns()
self._tasks[subscription.id] = asyncio.create_task(
self._handle_queue(subscription.id)
)
await self._queues[subscription].put(data)

async def _handle_queue(self, subscription: Subscription):
while True:
data = await self._queues[subscription.id].get()
try:
await subscription.callback(data)
except Exception:
logger.opt(exception=True).warning(
"Unable to handle subscription {} data {}:", subscription, data
)
if self._queues[subscription.id].empty():
self._queue_flush_times[subscription.id] = monotonic_ns()
self._queue_flush_times.move_to_end(subscription.id)

async def _cleanup_queues(self):
while True:
most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME
await self._queues_lock.acquire()
try:
# This code is safe because there are no any awaits except
# for the lock, if this changes then we need to also lock
# the unlocked segment of trigger and _handle_queue with a
# per-subscription lock and hold it while we delete the
# subscription so we don't miss any messages
subscriptions_to_delete = []
for subscription_id, last_flush in self._queue_flush_times.items():
if last_flush > most_recent_flush_to_delete:
# We're done
break
if self._queues[subscription_id].empty():
subscriptions_to_delete.append(subscription_id)
for subscription_id in subscriptions_to_delete:
self._tasks[subscription_id].cancel()
del self._tasks[subscription_id]
del self._queues[subscription_id]
del self._queue_flush_times[subscription_id]
finally:
self._queues_lock.release()
await asyncio.sleep(1)


class EventNotifier:
Expand Down Expand Up @@ -70,6 +139,7 @@ def __init__(self):
self._on_unsubscribe_events = []
# List of restriction checks to perform on every action on the channel
self._channel_restrictions = []
self._subscription_pusher = SubscriptionPusher()

def gen_subscriber_id(self):
return gen_uid()
Expand Down Expand Up @@ -175,7 +245,7 @@ async def trigger_callback(
subscriber_id: SubscriberId,
subscription: Subscription,
):
await subscription.callback(subscription, data)
await self._subscription_pusher.trigger(subscription, data)

async def callback_subscribers(
self,
Expand Down

0 comments on commit 43f84cb

Please sign in to comment.