Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize publishing #59

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 69 additions & 49 deletions fastapi_websocket_pubsub/event_broadcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fastapi_websocket_rpc.utils import gen_uid


logger = get_logger('EventBroadcaster')
logger = get_logger("EventBroadcaster")


# Cross service broadcast consts
Expand All @@ -35,9 +35,14 @@ class EventBroadcasterContextManager:
Friend-like class of EventBroadcaster (accessing "protected" members )
"""

def __init__(self, event_broadcaster: "EventBroadcaster", listen: bool = True, share: bool = True) -> None:
def __init__(
self,
event_broadcaster: "EventBroadcaster",
listen: bool = True,
share: bool = True,
) -> None:
"""
Provide a context manager for an EventBroadcaster, managing if it listens to events coming from the broadcaster
Provide a context manager for an EventBroadcaster, managing if it listens to events coming from the broadcaster
and if it subscribes to the internal notifier to share its events with the broadcaster

Args:
Expand All @@ -48,15 +53,16 @@ def __init__(self, event_broadcaster: "EventBroadcaster", listen: bool = True, s
self._event_broadcaster = event_broadcaster
self._share: bool = share
self._listen: bool = listen
self._lock = asyncio.Lock()

async def __aenter__(self):
async with self._lock:
async with self._event_broadcaster._context_manager_lock:
if self._listen:
self._event_broadcaster._listen_count += 1
if self._event_broadcaster._listen_count == 1:
# We have our first listener start the read-task for it (And all those who'd follow)
logger.info("Listening for incoming events from broadcast channel (first listener started)")
logger.info(
"Listening for incoming events from broadcast channel (first listener started)"
)
# Start task listening on incoming broadcasts
self._event_broadcaster.start_reader_task()

Expand All @@ -66,15 +72,17 @@ async def __aenter__(self):
# We have our first publisher
# Init the broadcast used for sharing (reading has its own)
self._event_broadcaster._acquire_sharing_broadcast_channel()
logger.debug("Subscribing to ALL_TOPICS, and sharing messages with broadcast channel")
# Subscribe to internal events form our own event notifier and broadcast them
await self._event_broadcaster._subscribe_to_all_topics()
logger.debug(
"Subscribing to ALL_TOPICS, and sharing messages with broadcast channel"
)
else:
logger.debug(f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}")
logger.debug(
f"Did not subscribe to ALL_TOPICS: share count == {self._event_broadcaster._share_count}"
)
return self

async def __aexit__(self, exc_type, exc, tb):
async with self._lock:
async with self._event_broadcaster._context_manager_lock:
try:
if self._listen:
self._event_broadcaster._listen_count -= 1
Expand All @@ -87,12 +95,10 @@ async def __aexit__(self, exc_type, exc, tb):
self._event_broadcaster._subscription_task = None

if self._share:
self._event_broadcaster._share_count -= 1
self._event_broadcaster._share_count -= 1
# if this was last sharer - we can stop subscribing to internal events - we aren't sharing anymore
if self._event_broadcaster._share_count == 0:
# Unsubscribe from internal events
logger.debug("Unsubscribing from ALL TOPICS")
await self._event_broadcaster._unsubscribe_from_topics()
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand why you dropped the _unsubscribe_from_topics ?


except:
logger.exception("Failed to exit EventBroadcaster context")
Expand All @@ -110,8 +116,14 @@ class EventBroadcaster:
<Your Code>
"""

def __init__(self, broadcast_url: str, notifier: EventNotifier, channel="EventNotifier",
broadcast_type=None, is_publish_only=False) -> None:
def __init__(
self,
broadcast_url: str,
notifier: EventNotifier,
channel="EventNotifier",
broadcast_type=None,
is_publish_only=False,
) -> None:
"""

Args:
Expand All @@ -138,26 +150,31 @@ def __init__(self, broadcast_url: str, notifier: EventNotifier, channel="EventNo
self._publish_lock = None
# used to track creation / removal of resources needed per type (reader task->listen, and subscription to internal events->share)
self._listen_count: int = 0
self._share_count: int = 0
# If we opt to manage the context directly (i.e. call async with on the event broadcaster itself)
self._share_count: int = 0
# If we opt to manage the context directly (i.e. call async with on the event broadcaster itself)
self._context_manager = None
self._context_manager_lock = asyncio.Lock()


async def __broadcast_notifications__(self, subscription: Subscription, data):
async def __broadcast_notifications__(self, topics: TopicList, data):
"""
Share incoming internal notifications with the entire broadcast channel

Args:
subscription (Subscription): the subscription that got triggered
data: the event data
"""
logger.info("Broadcasting incoming event: {}".format({'topic': subscription.topic, 'notifier_id': self._id}))
note = BroadcastNotification(notifier_id=self._id, topics=[
subscription.topic], data=data)
logger.info(
"Broadcasting incoming event: {}".format(
{"topics": topics, "notifier_id": self._id}
)
)
note = BroadcastNotification(notifier_id=self._id, topics=topics, data=data)
# Publish event to broadcast
async with self._publish_lock:
async with self._sharing_broadcast_channel:
await self._sharing_broadcast_channel.publish(self._channel, note.json())
await self._sharing_broadcast_channel.publish(
self._channel, note.json()
)

def _acquire_sharing_broadcast_channel(self):
"""
Expand All @@ -166,14 +183,6 @@ def _acquire_sharing_broadcast_channel(self):
self._publish_lock = asyncio.Lock()
self._sharing_broadcast_channel = self._broadcast_type(self._broadcast_url)

async def _subscribe_to_all_topics(self):
return await self._notifier.subscribe(self._id,
ALL_TOPICS,
self.__broadcast_notifications__)

async def _unsubscribe_from_topics(self):
return await self._notifier.unsubscribe(self._id)

def get_context(self, listen=True, share=True):
"""
Create a new context manager you can call 'async with' on, configuring the broadcaster for listening, sharing, or both.
Expand All @@ -183,16 +192,16 @@ def get_context(self, listen=True, share=True):
share (bool, optional): Should we share events with the broadcast channel. Defaults to True.

Returns:
EventBroadcasterContextManager: the context
EventBroadcasterContextManager: the context
"""
return EventBroadcasterContextManager(self, listen=listen, share=share)

def get_listening_context(self):
return EventBroadcasterContextManager(self, listen=True, share=False)

def get_sharing_context(self):
return EventBroadcasterContextManager(self, listen=False, share=True)

async def __aenter__(self):
"""
Convince caller (also backward compaltability)
Expand All @@ -201,7 +210,6 @@ async def __aenter__(self):
self._context_manager = self.get_context(listen=not self._is_publish_only)
return await self._context_manager.__aenter__()


async def __aexit__(self, exc_type, exc, tb):
await self._context_manager.__aexit__(exc_type, exc, tb)

Expand All @@ -215,14 +223,15 @@ def start_reader_task(self):
# Make sure a task wasn't started already
if self._subscription_task is not None:
# we already started a task for this worker process
logger.debug("No need for listen task, already started broadcast listen task for this notifier")
logger.debug(
"No need for listen task, already started broadcast listen task for this notifier"
)
return
# Trigger the task
logger.debug("Spawning broadcast listen task")
self._subscription_task = asyncio.create_task(
self.__read_notifications__())
self._subscription_task = asyncio.create_task(self.__read_notifications__())
return self._subscription_task

def get_reader_task(self):
return self._subscription_task

Expand All @@ -235,15 +244,26 @@ async def __read_notifications__(self):
listening_broadcast_channel = self._broadcast_type(self._broadcast_url)
async with listening_broadcast_channel:
# Subscribe to our channel
async with listening_broadcast_channel.subscribe(channel=self._channel) as subscriber:
async with listening_broadcast_channel.subscribe(
channel=self._channel
) as subscriber:
async for event in subscriber:
try:
notification = BroadcastNotification.parse_raw(
event.message)
# Avoid re-publishing our own broadcasts
if notification.notifier_id != self._id:
logger.info("Handling incoming broadcast event: {}".format({'topics': notification.topics, 'src': notification.notifier_id}))
# Notify subscribers of message received from broadcast
await self._notifier.notify(notification.topics, notification.data, notifier_id=self._id)
notification = BroadcastNotification.parse_raw(event.message)
logger.debug(
"Handling incoming broadcast event: {}".format(
{
"topics": notification.topics,
"src": notification.notifier_id,
}
)
)
# Notify subscribers of message received from broadcast
await self._notifier.notify(
notification.topics, notification.data, notifier_id=self._id
)
except:
logger.exception("Failed handling incoming broadcast")
logger.info(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems more like a debug message than info

"No more events to read from subscriber (underlying connection closed)"
)
47 changes: 35 additions & 12 deletions fastapi_websocket_pubsub/pub_sub_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
on_connect: List[Coroutine] = None,
on_disconnect: List[Coroutine] = None,
rpc_channel_get_remote_id: bool = False,
ignore_broadcaster_disconnected = True,
ignore_broadcaster_disconnected=True,
):
"""
The PubSub endpoint recives subscriptions from clients and publishes data back to them upon receiving relevant publications.
Expand Down Expand Up @@ -85,12 +85,33 @@ def __init__(
# Separate if for the server to subscribe to its own events
self._subscriber_id: str = self.notifier.gen_subscriber_id()
self._ignore_broadcaster_disconnected = ignore_broadcaster_disconnected
self._broadcaster_sharing_context = None

async def subscribe(
self, topics: Union[TopicList, ALL_TOPICS], callback: EventCallback
) -> List[Subscription]:
return await self.notifier.subscribe(self._subscriber_id, topics, callback)

async def _publish_through_broadcaster(
self, topics: Union[TopicList, Topic], data=None
):
if self._broadcaster_sharing_context is None:
logger.debug(f"Getting new broadcaster sharing context")
self._broadcaster_sharing_context = self.broadcaster.get_context(
listen=True, share=True
)
logger.debug(f"Acquiring broadcaster sharing context")
try:
async with self._broadcaster_sharing_context:
# We don't notify notifier here, broadcaster listens to its own back-channel and notifies notifier,
# Thus all subscribers are notified in the same order (local and remote)
await self.broadcaster.__broadcast_notifications__(topics, data)
except Exception:
# Could check if the exception has to do with disconnection, but just in case better to restart sharing context anyway
logger.warning(f"Exception in publish, resetting sharing context")
self._broadcaster_sharing_context = None
raise

async def publish(self, topics: Union[TopicList, Topic], data=None):
"""
Publish events to subscribres of given topics currently connected to the endpoint
Expand All @@ -99,15 +120,12 @@ async def publish(self, topics: Union[TopicList, Topic], data=None):
topics (Union[TopicList, Topic]): topics to publish to relevant subscribers
data (Any, optional): Event data to be passed to each subscriber. Defaults to None.
"""
# if we have a broadcaster make sure we share with it (no matter where this call comes from)
# sharing here means - the broadcaster listens in to the notifier as well
logger.debug(f"Publishing message to topics: {topics}")
if self.broadcaster is not None:
logger.debug(f"Acquiring broadcaster sharing context")
async with self.broadcaster.get_context(listen=False, share=True):
await self.notifier.notify(topics, data, notifier_id=self._id)
# otherwise just notify
# Use broadcaster if available, notifier would be notified when broadcaster reads its own notification
await self._publish_through_broadcaster(topics, data)
else:
# otherwise just notify
await self.notifier.notify(topics, data, notifier_id=self._id)

# canonical name (backward compatability)
Expand All @@ -132,14 +150,19 @@ async def main_loop(self, websocket: WebSocket, client_id: str = None, **kwargs)
async with self.broadcaster:
logger.debug("Entering endpoint's main loop with broadcaster")
if self._ignore_broadcaster_disconnected:
await self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
await self.endpoint.main_loop(
websocket, client_id=client_id, **kwargs
)
else:
main_loop_task = asyncio.create_task(
self.endpoint.main_loop(websocket, client_id=client_id, **kwargs)
self.endpoint.main_loop(
websocket, client_id=client_id, **kwargs
)
)
done, pending = await asyncio.wait(
[main_loop_task, self.broadcaster.get_reader_task()],
return_when=asyncio.FIRST_COMPLETED,
)
done, pending = await asyncio.wait([main_loop_task,
self.broadcaster.get_reader_task()],
return_when=asyncio.FIRST_COMPLETED)
logger.debug(f"task is done: {done}")
# broadcaster's reader task is used by other endpoints and shouldn't be cancelled
if main_loop_task in pending:
Expand Down
Loading