From 35a472037eb3c13e8de6f870812e7670136eca1e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 09:21:44 -0700 Subject: [PATCH] Support using multiple/concurrent EventListeners --- griptape/events/event_bus.py | 22 ++++++++++++++-------- griptape/events/event_listener.py | 9 ++++----- tests/unit/events/test_event_listener.py | 23 ++++++++++++++++------- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 7060fd9edf..8d387147ab 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -1,8 +1,9 @@ from __future__ import annotations +import threading from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field from griptape.mixins.singleton_mixin import SingletonMixin @@ -13,6 +14,7 @@ @define class _EventBus(SingletonMixin): _event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners") + _thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()), alias="_thread_lock") @property def event_listeners(self) -> list[EventListener]: @@ -22,23 +24,27 @@ def add_event_listeners(self, event_listeners: list[EventListener]) -> list[Even return [self.add_event_listener(event_listener) for event_listener in event_listeners] def remove_event_listeners(self, event_listeners: list[EventListener]) -> None: - for event_listener in event_listeners: - self.remove_event_listener(event_listener) + with self._thread_lock: + for event_listener in event_listeners: + self.remove_event_listener(event_listener) def add_event_listener(self, event_listener: EventListener) -> EventListener: - if event_listener not in self._event_listeners: - self._event_listeners.append(event_listener) + with self._thread_lock: + if event_listener not in self._event_listeners: + self._event_listeners.append(event_listener) return event_listener def set_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: - self._event_listeners = event_listeners + with self._thread_lock: + self._event_listeners = event_listeners return self._event_listeners def remove_event_listener(self, event_listener: EventListener) -> None: - if event_listener in self._event_listeners: - self._event_listeners.remove(event_listener) + with self._thread_lock: + if event_listener in self._event_listeners: + self._event_listeners.remove(event_listener) def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self._event_listeners: diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index 0d5c661da0..4efb3745eb 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -1,5 +1,6 @@ from __future__ import annotations +import uuid from typing import TYPE_CHECKING, Callable, Optional from attrs import Factory, define, field @@ -12,6 +13,7 @@ @define class EventListener: + id: str = field(default=Factory(lambda: uuid.uuid4().hex), metadata={"serializable": True}, kw_only=True) handler: Callable[[BaseEvent], Optional[dict]] = field(default=Factory(lambda: lambda event: event.to_dict())) event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True) driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) @@ -21,17 +23,14 @@ class EventListener: def __enter__(self) -> EventListener: from griptape.events import EventBus - self._last_event_listeners = [*EventBus.event_listeners] - - EventBus.set_event_listeners([self]) + EventBus.add_event_listener(self) return self def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002 from griptape.events import EventBus - if self._last_event_listeners is not None: - EventBus.set_event_listeners(self._last_event_listeners) + EventBus.remove_event_listener(self) self._last_event_listeners = None diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index a9174ffde7..6af1213b16 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -94,8 +94,8 @@ def test_add_remove_event_listener(self, pipeline): mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, id="1", event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, id="1", event_types=[StartPromptEvent])) event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) @@ -137,10 +137,19 @@ def event_handler(event: BaseEvent): mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False) def test_context_manager(self): - EventBus.add_event_listeners([EventListener()]) - last_event_listeners = EventBus.event_listeners + e1 = EventListener() + EventBus.add_event_listeners([e1]) - with EventListener() as e: - assert EventBus.event_listeners == [e] + with EventListener() as e2: + assert EventBus.event_listeners == [e1, e2] - assert EventBus.event_listeners == last_event_listeners + assert EventBus.event_listeners == [e1] + + def test_context_manager_multiple(self): + e1 = EventListener() + EventBus.add_event_listener(e1) + + with EventListener() as e2, EventListener() as e3: + assert EventBus.event_listeners == [e1, e2, e3] + + assert EventBus.event_listeners == [e1]