diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 9239e66bd..6ffd65550 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -10,7 +10,11 @@ @define class _EventBus: - event_listeners: list[EventListener] = field(factory=list, kw_only=True) + _event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners") + + @property + def event_listeners(self) -> list[EventListener]: + return self._event_listeners def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: return [self.add_event_listener(event_listener) for event_listener in event_listeners] @@ -20,18 +24,21 @@ def remove_event_listeners(self, event_listeners: list[EventListener]) -> None: self.remove_event_listener(event_listener) def add_event_listener(self, event_listener: EventListener) -> EventListener: - if event_listener not in self.event_listeners: - self.event_listeners.append(event_listener) + if event_listener not in self._event_listeners: + self._event_listeners.append(event_listener) return event_listener def remove_event_listener(self, event_listener: EventListener) -> None: - if event_listener in self.event_listeners: - self.event_listeners.remove(event_listener) + if event_listener in self._event_listeners: + self._event_listeners.remove(event_listener) def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: - for event_listener in self.event_listeners: + for event_listener in self._event_listeners: event_listener.publish_event(event, flush=flush) + def clear_event_listeners(self) -> None: + self._event_listeners.clear() + EventBus = _EventBus() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0be2f9758..7a73b041f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,8 +5,8 @@ @pytest.fixture(autouse=True) def event_bus(): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() yield EventBus - EventBus.event_listeners = [] + EventBus.clear_event_listeners() diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index fd862913e..d237bb3b4 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -35,7 +35,7 @@ def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.event_listeners = [EventListener(handler=mock_handler)] + EventBus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 5601aef34..f3d9823d3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,17 +59,19 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.event_listeners = [ - EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), - EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), - EventListener(start_task_event_handler, event_types=[StartTaskEvent]), - EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), - EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), - EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), - EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), - EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), - EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), - ] + EventBus.add_event_listeners( + [ + EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), + EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), + EventListener(start_task_event_handler, event_types=[StartTaskEvent]), + EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), + EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), + EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), + EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), + EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), + EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), + ] + ) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -87,7 +89,7 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 636515106..d6e4da8b6 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.event_listeners = [EventListener(handler=Mock())] + EventBus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(),