Skip to content

Commit

Permalink
Add ability to use EventListener as Context Manager (#1163)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Sep 11, 2024
1 parent 9d9b643 commit 4bf3d57
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Parameter `meta: dict` on `BaseEvent`.
- `AzureOpenAiTextToSpeechDriver`.
- Ability to use Event Listeners as Context Managers for temporarily setting the Event Bus listeners.

### Changed
- **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`.
Expand Down
9 changes: 9 additions & 0 deletions docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ Handler 1 <class 'griptape.events.finish_structure_run_event.FinishStructureRunE
Handler 2 <class 'griptape.events.finish_structure_run_event.FinishStructureRunEvent'>
```

## Context Managers

You can also use [EventListener](../../reference/griptape/events/event_listener.md)s as a Python Context Manager.
The `EventListener` will automatically be added and removed from the [EventBus](../../reference/griptape/events/event_bus.md) when entering and exiting the context.

```python
--8<-- "docs/griptape-framework/misc/src/events_context.py"
```

## Streaming


Expand Down
13 changes: 13 additions & 0 deletions docs/griptape-framework/misc/src/events_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from griptape.events import EventBus, EventListener, FinishStructureRunEvent, StartPromptEvent
from griptape.structures import Agent

EventBus.add_event_listeners(
[EventListener(lambda e: print(f"Out of context: {e.type}"), event_types=[StartPromptEvent])]
)

agent = Agent(input="Hello!")

with EventListener(lambda e: print(f"In context: {e.type}"), event_types=[FinishStructureRunEvent]):
agent.run()

agent.run()
17 changes: 11 additions & 6 deletions griptape/events/event_bus.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import threading
from typing import TYPE_CHECKING

from attrs import define, field
from attrs import Factory, define, field

from griptape.mixins.singleton_mixin import SingletonMixin

Expand All @@ -13,6 +14,7 @@
@define
class _EventBus(SingletonMixin):
_event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners")
_thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()), alias="_thread_lock")

@property
def event_listeners(self) -> list[EventListener]:
Expand All @@ -26,21 +28,24 @@ def remove_event_listeners(self, event_listeners: list[EventListener]) -> None:
self.remove_event_listener(event_listener)

def add_event_listener(self, event_listener: EventListener) -> EventListener:
if event_listener not in self._event_listeners:
self._event_listeners.append(event_listener)
with self._thread_lock:
if event_listener not in self._event_listeners:
self._event_listeners.append(event_listener)

return event_listener

def remove_event_listener(self, event_listener: EventListener) -> None:
if event_listener in self._event_listeners:
self._event_listeners.remove(event_listener)
with self._thread_lock:
if event_listener in self._event_listeners:
self._event_listeners.remove(event_listener)

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
for event_listener in self._event_listeners:
event_listener.publish_event(event, flush=flush)

def clear_event_listeners(self) -> None:
self._event_listeners.clear()
with self._thread_lock:
self._event_listeners.clear()


EventBus = _EventBus()
16 changes: 16 additions & 0 deletions griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ class EventListener:
event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True)
driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True)

_last_event_listeners: Optional[list[EventListener]] = field(default=None)

def __enter__(self) -> EventListener:
from griptape.events import EventBus

EventBus.add_event_listener(self)

return self

def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002
from griptape.events import EventBus

EventBus.remove_event_listener(self)

self._last_event_listeners = None

def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None:
event_types = self.event_types

Expand Down
18 changes: 18 additions & 0 deletions tests/unit/events/test_event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,21 @@ def event_handler(event: BaseEvent):
event_listener.publish_event(mock_event)

mock_event_listener_driver.publish_event.assert_called_once_with({"event": mock_event.to_dict()}, flush=False)

def test_context_manager(self):
e1 = EventListener()
EventBus.add_event_listeners([e1])

with EventListener() as e2:
assert EventBus.event_listeners == [e1, e2]

assert EventBus.event_listeners == [e1]

def test_context_manager_multiple(self):
e1 = EventListener()
EventBus.add_event_listener(e1)

with EventListener() as e2, EventListener() as e3:
assert EventBus.event_listeners == [e1, e2, e3]

assert EventBus.event_listeners == [e1]

0 comments on commit 4bf3d57

Please sign in to comment.