diff --git a/src/aact/nodes/base.py b/src/aact/nodes/base.py index 188e033..2daa8d0 100644 --- a/src/aact/nodes/base.py +++ b/src/aact/nodes/base.py @@ -1,3 +1,5 @@ +from asyncio import CancelledError +import logging import sys if sys.version_info >= (3, 11): @@ -17,6 +19,10 @@ OutputType = TypeVar("OutputType", covariant=True, bound=DataModel) +class NodeExitSignal(CancelledError): + """Node exit signal, which is raised in nodes' event handler. It is used to exit the node gracefully.""" + + class Node(BaseModel, Generic[InputType, OutputType]): input_channel_types: dict[str, Type[InputType]] output_channel_types: dict[str, Type[OutputType]] @@ -37,6 +43,7 @@ def __init__( self.r: Redis = Redis.from_url(redis_url) self.pubsub = self.r.pubsub() + self.logger = logging.getLogger("aact.nodes.base.Node") async def __aenter__(self) -> Self: try: @@ -67,12 +74,18 @@ async def _wait_for_input( async def event_loop( self, ) -> None: - async for input_channel, input_message in self._wait_for_input(): - async for output_channel, output_message in self.event_handler( - input_channel, input_message - ): - await self.r.publish(output_channel, output_message.model_dump_json()) - raise Exception("Event loop exited unexpectedly") + try: + async for input_channel, input_message in self._wait_for_input(): + async for output_channel, output_message in self.event_handler( + input_channel, input_message + ): + await self.r.publish( + output_channel, output_message.model_dump_json() + ) + except NodeExitSignal as e: + self.logger.info(f"Event loop cancelled: {e}. Exiting gracefully.") + except Exception as e: + raise e @abstractmethod async def event_handler(