From 0f193854cc4b4230301b22578d5e14b42c1e72f0 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:17:34 -0700 Subject: [PATCH] Rename EventBus to event_bus --- CHANGELOG.md | 4 +-- .../drivers/event-listener-drivers.md | 24 +++++++-------- docs/griptape-framework/misc/events.md | 22 +++++++------- .../base_audio_transcription_driver.py | 6 ++-- .../base_image_generation_driver.py | 6 ++-- .../image_query/base_image_query_driver.py | 6 ++-- griptape/drivers/prompt/base_prompt_driver.py | 12 ++++---- .../base_text_to_speech_driver.py | 6 ++-- griptape/events/__init__.py | 4 +-- griptape/events/event_bus.py | 2 +- griptape/structures/structure.py | 6 ++-- griptape/tasks/actions_subtask.py | 6 ++-- griptape/tasks/base_task.py | 6 ++-- griptape/utils/stream.py | 6 ++-- tests/unit/conftest.py | 10 +++---- .../test_base_audio_transcription_driver.py | 4 +-- .../test_base_image_generation_driver.py | 10 +++---- .../test_base_image_query_driver.py | 4 +-- .../drivers/prompt/test_base_prompt_driver.py | 4 +-- .../test_base_audio_transcription_driver.py | 4 +-- tests/unit/events/test_event_bus.py | 30 +++++++++---------- tests/unit/events/test_event_listener.py | 30 +++++++++---------- tests/unit/tasks/test_base_task.py | 6 ++-- 23 files changed, 109 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6748299d0..7a95701c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. -- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. +- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. ### Changed -- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index db02cd77a..c3c92cfe1 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,12 +14,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -84,12 +84,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -132,12 +132,12 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -179,11 +179,11 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -210,11 +210,11 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -242,11 +242,11 @@ from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 23ebcdc2a..b3f4a77fd 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. +You can configure the global [event_bus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,14 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, - EventBus + event_bus ) def handler(event: BaseEvent): print(event.__class__) -EventBus.add_event_listeners([ +event_bus.add_event_listeners([ EventListener( handler, event_types=[ @@ -69,7 +69,7 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener, EventBus +from griptape.events import BaseEvent, EventListener, event_bus @@ -80,7 +80,7 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] @@ -131,7 +131,7 @@ Handler 2 None: - EventBus.publish_event(StartAudioTranscriptionEvent()) + event_bus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - EventBus.publish_event(FinishAudioTranscriptionEvent()) + event_bus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 8dfca5945..360fba8c9 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC) model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + event_bus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishImageGenerationEvent()) + event_bus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 28c571328..b1050b85c 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.events import FinishImageQueryEvent, StartImageQueryEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,12 +17,12 @@ class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - EventBus.publish_event( + event_bus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - EventBus.publish_event(FinishImageQueryEvent(result=result)) + event_bus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..8044469b5 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,7 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + event_bus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text)) + event_bus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content))) + event_bus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) + event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index cb11cc498..c74264dc1 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -19,10 +19,10 @@ class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + event_bus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishTextToSpeechEvent()) + event_bus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b3e2f3a79..431927663 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,7 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent -from .event_bus import EventBus +from .event_bus import event_bus __all__ = [ "BaseEvent", @@ -49,5 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", - "EventBus", + "event_bus", ] diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 6ffd65550..a956f7deb 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -41,4 +41,4 @@ def clear_event_listeners(self) -> None: self._event_listeners.clear() -EventBus = _EventBus() +event_bus = _EventBus() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index df7113c23..d68457ebc 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,7 +28,7 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent +from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory @@ -257,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - EventBus.publish_event( + event_bus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -269,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 07f49f52a..d600c80a5 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - EventBus.publish_event( + event_bus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 9a8361e6c..ade656f87 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent +from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 4a7899b2a..fd64a0f52 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventListener, FinishPromptEvent, FinishStructureRunEvent, event_bus if TYPE_CHECKING: from collections.abc import Iterator @@ -61,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - EventBus.add_event_listener(stream_event_listener) + event_bus.add_event_listener(stream_event_listener) self.structure.run(*args) - EventBus.remove_event_listener(stream_event_listener) + event_bus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7a73b041f..e462ede90 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,12 +1,12 @@ import pytest -from griptape.events import EventBus +from griptape.events import event_bus @pytest.fixture(autouse=True) -def event_bus(): - EventBus.clear_event_listeners() +def mock_event_bus(): + event_bus.clear_event_listeners() - yield EventBus + yield event_bus - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index fc41837fd..6fcab26e5 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 96b615a58..ab7b33ae8 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index a77fb268e..d8ba6b60f 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b6b0c600..52b7d5c0d 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _EventBus +from griptape.events.event_bus import _event_bus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_EventBus, "publish_event") + mock_publish_event = mocker.patch.object(_event_bus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index ab448c7c1..19493aa0f 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index d237bb3b4..7eb87036a 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -1,45 +1,45 @@ from unittest.mock import Mock -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_event import MockEvent class TestEventBus: def test_add_event_listeners(self): - EventBus.add_event_listeners([EventListener(), EventListener()]) - assert len(EventBus.event_listeners) == 2 + event_bus.add_event_listeners([EventListener(), EventListener()]) + assert len(event_bus.event_listeners) == 2 def test_remove_event_listeners(self): listeners = [EventListener(), EventListener()] - EventBus.add_event_listeners(listeners) - EventBus.remove_event_listeners(listeners) - assert len(EventBus.event_listeners) == 0 + event_bus.add_event_listeners(listeners) + event_bus.remove_event_listeners(listeners) + assert len(event_bus.event_listeners) == 0 def test_add_event_listener(self): - EventBus.add_event_listener(EventListener()) - EventBus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) - assert len(EventBus.event_listeners) == 2 + assert len(event_bus.event_listeners) == 2 def test_remove_event_listener(self): listener = EventListener() - EventBus.add_event_listener(listener) - EventBus.remove_event_listener(listener) + event_bus.add_event_listener(listener) + event_bus.remove_event_listener(listener) - assert len(EventBus.event_listeners) == 0 + assert len(event_bus.event_listeners) == 0 def test_remove_unknown_event_listener(self): - EventBus.remove_event_listener(EventListener()) + event_bus.remove_event_listener(EventListener()) def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.add_event_listeners([EventListener(handler=mock_handler)]) + event_bus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When - EventBus.publish_event(mock_event) + event_bus.publish_event(mock_event) # Then mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index f3d9823d3..50763e0c3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,7 +4,6 @@ from griptape.events import ( CompletionChunkEvent, - EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -14,6 +13,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + event_bus, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) + event_bus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.add_event_listeners( + event_bus.add_event_listeners( [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), @@ -89,25 +89,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = event_bus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = event_bus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) + event_listener_5 = event_bus.add_event_listener(EventListener(mock2)) - assert len(EventBus.event_listeners) == 4 + assert len(event_bus.event_listeners) == 4 - EventBus.remove_event_listener(event_listener_1) - EventBus.remove_event_listener(event_listener_3) - EventBus.remove_event_listener(event_listener_4) - EventBus.remove_event_listener(event_listener_5) - assert len(EventBus.event_listeners) == 0 + event_bus.remove_event_listener(event_listener_1) + event_bus.remove_event_listener(event_listener_3) + event_bus.remove_event_listener(event_listener_4) + event_bus.remove_event_listener(event_listener_5) + assert len(event_bus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d6e4da8b6..aa402bb48 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.add_event_listeners([EventListener(handler=Mock())]) + event_bus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), @@ -118,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert EventBus.event_listeners[0].handler.call_count == 2 + assert event_bus.event_listeners[0].handler.call_count == 2