diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c18e59b04..99bd473c28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Parameter `meta: dict` on `BaseEvent`. - `AzureOpenAiTextToSpeechDriver`. +- Ability to use Event Listeners as Context Managers for temporarily setting the Event Bus listeners. ### Changed - **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`. diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 3c4181aeef..beb02d66a8 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -73,6 +73,15 @@ Handler 1 ``` +## Context Managers + +You can also use [EventListener](../../reference/griptape/events/event_listener.md)s as a Python Context Manager. +The `EventListener` will automatically be added and removed from the [EventBus](../../reference/griptape/events/event_bus.md) when entering and exiting the context. + +```python +--8<-- "docs/griptape-framework/misc/src/events_context.py" +``` + ## Streaming diff --git a/docs/griptape-framework/misc/src/events_context.py b/docs/griptape-framework/misc/src/events_context.py new file mode 100644 index 0000000000..f9597ec153 --- /dev/null +++ b/docs/griptape-framework/misc/src/events_context.py @@ -0,0 +1,13 @@ +from griptape.events import EventBus, EventListener, FinishStructureRunEvent, StartPromptEvent +from griptape.structures import Agent + +EventBus.add_event_listeners( + [EventListener(lambda e: print(f"Out of context: {e.type}"), event_types=[StartPromptEvent])] +) + +agent = Agent(input="Hello!") + +with EventListener(lambda e: print(f"In context: {e.type}"), event_types=[FinishStructureRunEvent]): + agent.run() + +agent.run() diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 3ddc325ffa..e4b6518523 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -31,6 +31,11 @@ def add_event_listener(self, event_listener: EventListener) -> EventListener: return event_listener + def set_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: + self.clear_event_listeners() + + return self.add_event_listeners(event_listeners) + def remove_event_listener(self, event_listener: EventListener) -> None: if event_listener in self._event_listeners: self._event_listeners.remove(event_listener) diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index 74171d3753..0d5c661da0 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -16,6 +16,25 @@ class EventListener: event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True) driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) + _last_event_listeners: Optional[list[EventListener]] = field(default=None) + + def __enter__(self) -> EventListener: + from griptape.events import EventBus + + self._last_event_listeners = [*EventBus.event_listeners] + + EventBus.set_event_listeners([self]) + + return self + + def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002 + from griptape.events import EventBus + + if self._last_event_listeners is not None: + EventBus.set_event_listeners(self._last_event_listeners) + + self._last_event_listeners = None + def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: event_types = self.event_types diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index cc432dafb4..084dd4a84c 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -33,6 +33,11 @@ def test_remove_event_listener(self): assert len(EventBus.event_listeners) == 0 + def test_set_event_listeners(self): + listeners = [EventListener(), EventListener()] + EventBus.set_event_listeners(listeners) + assert EventBus.event_listeners == listeners + def test_remove_unknown_event_listener(self): EventBus.remove_event_listener(EventListener()) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index a6d90d4fc9..a9174ffde7 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -135,3 +135,12 @@ def event_handler(event: BaseEvent): event_listener.publish_event(mock_event) mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False) + + def test_context_manager(self): + EventBus.add_event_listeners([EventListener()]) + last_event_listeners = EventBus.event_listeners + + with EventListener() as e: + assert EventBus.event_listeners == [e] + + assert EventBus.event_listeners == last_event_listeners