diff --git a/fastapi_websocket_pubsub/event_broadcaster.py b/fastapi_websocket_pubsub/event_broadcaster.py index f3e43fe..966ce7c 100644 --- a/fastapi_websocket_pubsub/event_broadcaster.py +++ b/fastapi_websocket_pubsub/event_broadcaster.py @@ -65,7 +65,7 @@ async def __aenter__(self): "Listening for incoming events from broadcast channel (first listener started)" ) # Start task listening on incoming broadcasts - self._event_broadcaster.start_reader_task() + await self._event_broadcaster.start_reader_task() if self._share: self._event_broadcaster._share_count += 1 @@ -159,6 +159,7 @@ def __init__( self._context_manager = None self._context_manager_lock = asyncio.Lock() self._tasks = set() + self.listening_broadcast_channel = None async def __broadcast_notifications__(self, subscription: Subscription, data): """ @@ -221,9 +222,12 @@ async def __aenter__(self): return await self._context_manager.__aenter__() async def __aexit__(self, exc_type, exc, tb): + if self.listening_broadcast_channel is not None: + await self.listening_broadcast_channel.disconnect() + self.listening_broadcast_channel = None await self._context_manager.__aexit__(exc_type, exc, tb) - def start_reader_task(self): + async def start_reader_task(self): """Spawn a task reading incoming broadcasts and posting them to the intreal notifier Raises: BroadcasterAlreadyStarted: if called more than once per context @@ -237,6 +241,20 @@ def start_reader_task(self): "No need for listen task, already started broadcast listen task for this notifier" ) return + + # Init new broadcast channel for reading + try: + if self.listening_broadcast_channel is None: + self.listening_broadcast_channel = self._broadcast_type( + self._broadcast_url + ) + await self.listening_broadcast_channel.connect() + except Exception as e: + logger.error( + f"Failed to connect to broadcast channel for reading incoming events: {e}" + ) + raise e + # Trigger the task logger.debug("Spawning broadcast listen task") self._subscription_task = asyncio.create_task(self.__read_notifications__()) @@ -249,44 +267,41 @@ async def __read_notifications__(self): """ read incoming broadcasts and posting them to the intreal notifier """ - logger.info("Starting broadcaster listener") - # Init new broadcast channel for reading - listening_broadcast_channel = self._broadcast_type(self._broadcast_url) - async with listening_broadcast_channel: - # Subscribe to our channel - async with listening_broadcast_channel.subscribe( - channel=self._channel - ) as subscriber: - async for event in subscriber: - try: - notification = BroadcastNotification.parse_raw(event.message) - # Avoid re-publishing our own broadcasts - if notification.notifier_id != self._id: - logger.debug( - "Handling incoming broadcast event: {}".format( - { - "topics": notification.topics, - "src": notification.notifier_id, - } - ) + logger.debug("Starting broadcaster listener") + # Subscribe to our channel + async with self.listening_broadcast_channel.subscribe( + channel=self._channel + ) as subscriber: + async for event in subscriber: + try: + notification = BroadcastNotification.parse_raw(event.message) + # Avoid re-publishing our own broadcasts + if notification.notifier_id != self._id: + logger.debug( + "Handling incoming broadcast event: {}".format( + { + "topics": notification.topics, + "src": notification.notifier_id, + } ) - # Notify subscribers of message received from broadcast - task = asyncio.create_task( - self._notifier.notify( - notification.topics, - notification.data, - notifier_id=self._id, - ) + ) + # Notify subscribers of message received from broadcast + task = asyncio.create_task( + self._notifier.notify( + notification.topics, + notification.data, + notifier_id=self._id, ) + ) - self._tasks.add(task) + self._tasks.add(task) - def cleanup(task): - self._tasks.remove(task) + def cleanup(task): + self._tasks.remove(task) - task.add_done_callback(cleanup) - except: - logger.exception("Failed handling incoming broadcast") - logger.info( - "No more events to read from subscriber (underlying connection closed)" - ) + task.add_done_callback(cleanup) + except: + logger.exception("Failed handling incoming broadcast") + logger.info( + "No more events to read from subscriber (underlying connection closed)" + )