Skip to content

Commit

Permalink
Add internal queues for pubsub message transmission to clients.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaul Kremer committed Aug 23, 2023
1 parent 8c08e11 commit 8015b9d
Showing 1 changed file with 63 additions and 2 deletions.
65 changes: 63 additions & 2 deletions fastapi_websocket_pubsub/event_notifier.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import copy
from collections import OrderedDict
from time import monotonic_ns
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union

from fastapi_websocket_rpc import RpcChannel
Expand Down Expand Up @@ -37,7 +39,65 @@ class Subscription(BaseModel):

# Publish event callback signature
def EventCallback(subscription: Subscription, data: Any):
pass
...


SUBSCRIPTION_TASK_CLEANUP_TIME = 60


class SubscriptionPusher:
def __init__(self):
self._queues: Dict[Subscription, asyncio.Queue] = {}
self._tasks: Dict[Subscription, asyncio.Task] = {}
self._queue_flush_times: OrderedDict[Subscription, int] = OrderedDict()
self._queues_lock = asyncio.Lock()
self._cleanup_task = asyncio.create_task(self._cleanup_queues)

async def trigger(self, subscription: Subscription, data):
need_create_task = False
await self._queues_lock.acquire()
try:
if not subscription in self._queues:
self._queues[subscription] = asyncio.Queue()
need_create_task = True
finally:
self._queues_lock.release()
if need_create_task:
self._queue_flush_times[subscription] = monotonic_ns()
self._tasks[subscription] = asyncio.create_task(
self._handle_queue(subscription)
)
await self._queues[subscription].put(data)

async def _handle_queue(self, subscription: Subscription):
while True:
data = await self._queues[subscription].get()
try:
await subscription.callback(data)
except Exception:
logger.opt(exception=True).warning(
"Unable to handle subscription {} data {}:", subscription, data
)
if self._queues[subscription].empty():
self._queue_flush_times[subscription] = monotonic_ns()
self._queue_flush_times.move_to_end(subscription)

async def _cleanup_queues(self):
most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME
await self._queues_lock.acquire()
try:
subscriptions_to_delete = []
for subscription, last_flush in self._queue_flush_times.items():
if last_flush > most_recent_flush_to_delete:
# We're done
break
for subscription in subscriptions_to_delete:
self._tasks[subscription].cancel()
del self._tasks[subscription]
del self._queues[subscription]
del self._queue_flush_times[subscription]
finally:
self._queues_lock.release()


class EventNotifier:
Expand Down Expand Up @@ -70,6 +130,7 @@ def __init__(self):
self._on_unsubscribe_events = []
# List of restriction checks to perform on every action on the channel
self._channel_restrictions = []
self._subscription_pusher = SubscriptionPusher()

def gen_subscriber_id(self):
return gen_uid()
Expand Down Expand Up @@ -175,7 +236,7 @@ async def trigger_callback(
subscriber_id: SubscriberId,
subscription: Subscription,
):
await subscription.callback(subscription, data)
await self._subscription_pusher.trigger(subscription, data)

async def callback_subscribers(
self,
Expand Down

0 comments on commit 8015b9d

Please sign in to comment.