From adf4836e0ee9c79da8f073a84716cb872dd7c519 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 16:23:25 -0700 Subject: [PATCH 01/40] Add global event bus --- CHANGELOG.md | 3 + docs/griptape-framework/misc/events.md | 61 ++++++++++--------- griptape/config/base_structure_config.py | 40 ------------ .../base_audio_transcription_driver.py | 10 +-- .../embedding/base_embedding_driver.py | 4 +- .../base_image_generation_driver.py | 10 +-- .../image_query/base_image_query_driver.py | 10 +-- .../base_conversation_memory_driver.py | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 16 ++--- .../base_text_to_speech_driver.py | 9 +-- .../vector/base_vector_store_driver.py | 4 +- griptape/events/__init__.py | 2 + .../event_bus.py} | 5 +- griptape/mixins/__init__.py | 2 - griptape/structures/structure.py | 12 ++-- griptape/tasks/actions_subtask.py | 6 +- griptape/tasks/base_task.py | 6 +- griptape/utils/stream.py | 9 +-- tests/unit/config/test_structure_config.py | 35 ----------- tests/unit/conftest.py | 12 ++++ .../test_base_audio_transcription_driver.py | 4 +- .../test_base_image_generation_driver.py | 9 +-- .../test_base_image_query_driver.py | 4 +- .../drivers/prompt/test_base_prompt_driver.py | 7 +-- .../test_base_audio_transcription_driver.py | 4 +- tests/unit/events/test_event_bus.py | 45 ++++++++++++++ tests/unit/events/test_event_listener.py | 29 ++++----- tests/unit/mixins/test_events_mixin.py | 59 ------------------ tests/unit/tasks/test_base_task.py | 5 +- 29 files changed, 176 insertions(+), 250 deletions(-) rename griptape/{mixins/event_publisher_mixin.py => events/event_bus.py} (96%) create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/events/test_event_bus.py delete mode 100644 tests/unit/mixins/test_events_mixin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3582ec02c..ea88983f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. +- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. ### Changed +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 1f50fd6d0..187321dc6 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can use [EventListener](../../reference/griptape/events/event_listener.md)s to listen for events during a Structure's execution. +You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,15 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, + EventBus ) def handler(event: BaseEvent): print(event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener( handler, event_types=[ @@ -44,7 +43,8 @@ agent = Agent( ], ) ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -69,7 +69,8 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener +from griptape.events import BaseEvent, EventListener, EventBus + def handler1(event: BaseEvent): @@ -79,13 +80,12 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -131,7 +131,7 @@ Handler 2 list: - return [ - self.prompt_driver, - self.image_generation_driver, - self.image_query_driver, - self.embedding_driver, - self.vector_store_driver, - self.conversation_memory_driver, - self.text_to_speech_driver, - self.audio_transcription_driver, - ] - - @property - def structure(self) -> Optional[Structure]: - return self._structure - - @structure.setter - def structure(self, structure: Structure) -> None: - if structure != self.structure: - event_publisher_drivers = [ - driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin) - ] - - for driver in event_publisher_drivers: - if self._event_listener is not None: - driver.remove_event_listener(self._event_listener) - - self._event_listener = EventListener(structure.publish_event) - for driver in event_publisher_drivers: - driver.add_event_listener(self._event_listener) - - self._structure = structure - def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() merged_config = dict_merge(base_config, config) diff --git a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py index c81ea1d5b..ae46c474c 100644 --- a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import AudioArtifact, TextArtifact @define -class BaseAudioTranscriptionDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self) -> None: - self.publish_event(StartAudioTranscriptionEvent()) + EventBus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - self.publish_event(FinishAudioTranscriptionEvent()) + EventBus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index 690726060..8998f00e5 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -7,7 +7,7 @@ from attrs import define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import TextArtifact @@ -15,7 +15,7 @@ @define -class BaseEmbeddingDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base Embedding Driver. Attributes: diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index f500d6d09..8dfca5945 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @define -class BaseImageGenerationDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - self.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - self.publish_event(FinishImageGenerationEvent()) + EventBus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index b39f198d4..28c571328 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,24 +5,24 @@ from attrs import define, field -from griptape.events import FinishImageQueryEvent, StartImageQueryEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @define -class BaseImageQueryDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - self.publish_event( + EventBus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - self.publish_event(FinishImageQueryEvent(result=result)) + EventBus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py index f13b82c29..1caeb902f 100644 --- a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory -class BaseConversationMemoryDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseConversationMemoryDriver(SerializableMixin, ABC): @abstractmethod def store(self, memory: BaseConversationMemory) -> None: ... diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index e5fd0408d..94e46e75d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,8 +16,8 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from collections.abc import Iterator @@ -26,7 +26,7 @@ @define(kw_only=True) -class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base class for the Prompt Drivers. Attributes: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublishe use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - self.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - self.publish_event( + EventBus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - self.publish_event(CompletionChunkEvent(token=content.text)) + EventBus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - self.publish_event(CompletionChunkEvent(token=str(content))) + EventBus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - self.publish_event(CompletionChunkEvent(token=content.partial_input)) + EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index 788d92974..cb11cc498 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,23 +5,24 @@ from attrs import define, field +from griptape.events import EventBus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @define -class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - self.publish_event(StartTextToSpeechEvent(prompts=prompts)) + EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - self.publish_event(FinishTextToSpeechEvent()) + EventBus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index d1da78188..ed1f2d589 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -10,14 +10,14 @@ from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.drivers import BaseEmbeddingDriver @define -class BaseVectorStoreDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseVectorStoreDriver(SerializableMixin, ABC): DEFAULT_QUERY_COUNT = 5 @dataclass diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 944a309eb..b3e2f3a79 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,6 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent +from .event_bus import EventBus __all__ = [ "BaseEvent", @@ -48,4 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", + "EventBus", ] diff --git a/griptape/mixins/event_publisher_mixin.py b/griptape/events/event_bus.py similarity index 96% rename from griptape/mixins/event_publisher_mixin.py rename to griptape/events/event_bus.py index 67a302ed6..9239e66bd 100644 --- a/griptape/mixins/event_publisher_mixin.py +++ b/griptape/events/event_bus.py @@ -9,7 +9,7 @@ @define -class EventPublisherMixin: +class _EventBus: event_listeners: list[EventListener] = field(factory=list, kw_only=True) def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: @@ -32,3 +32,6 @@ def remove_event_listener(self, event_listener: EventListener) -> None: def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self.event_listeners: event_listener.publish_event(event, flush=flush) + + +EventBus = _EventBus() diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index 944027c59..d9eea53c2 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -4,7 +4,6 @@ from .rule_mixin import RuleMixin from .serializable_mixin import SerializableMixin from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin -from .event_publisher_mixin import EventPublisherMixin __all__ = [ "ActivityMixin", @@ -13,5 +12,4 @@ "RuleMixin", "BlobArtifactFileOutputMixin", "SerializableMixin", - "EventPublisherMixin", ] diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 079e0b741..df7113c23 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,13 +28,11 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events.finish_structure_run_event import FinishStructureRunEvent -from griptape.events.start_structure_run_event import StartStructureRunEvent +from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.mixins import EventPublisherMixin from griptape.utils import deprecation_warn if TYPE_CHECKING: @@ -44,7 +42,7 @@ @define -class Structure(ABC, EventPublisherMixin): +class Structure(ABC): LOGGER_NAME = "griptape" id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) @@ -97,8 +95,6 @@ def __attrs_post_init__(self) -> None: if self.conversation_memory is not None: self.conversation_memory.structure = self - self.config.structure = self - tasks = self.tasks.copy() self.tasks.clear() self.add_tasks(*tasks) @@ -261,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - self.publish_event( + EventBus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -273,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - self.publish_event( + EventBus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index cde59d0ef..07f49f52a 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - self.structure.publish_event( + EventBus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - self.structure.publish_event( + EventBus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 8c50e4df9..9a8361e6c 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import FinishTaskEvent, StartTaskEvent +from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index bf33e5df8..4a7899b2a 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,10 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events.completion_chunk_event import CompletionChunkEvent -from griptape.events.event_listener import EventListener -from griptape.events.finish_prompt_event import FinishPromptEvent -from griptape.events.finish_structure_run_event import FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent if TYPE_CHECKING: from collections.abc import Iterator @@ -64,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - self.structure.add_event_listener(stream_event_listener) + EventBus.add_event_listener(stream_event_listener) self.structure.run(*args) - self.structure.remove_event_listener(stream_event_listener) + EventBus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index b9e3477e4..96a68628f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,7 +1,6 @@ import pytest from griptape.config import StructureConfig -from griptape.structures import Agent class TestStructureConfig: @@ -61,37 +60,3 @@ def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 assert config.prompt_driver.max_tokens == 10 - - def test_drivers(self, config): - assert config.drivers == [ - config.prompt_driver, - config.image_generation_driver, - config.image_query_driver, - config.embedding_driver, - config.vector_store_driver, - config.conversation_memory_driver, - config.text_to_speech_driver, - config.audio_transcription_driver, - ] - - def test_structure(self, config): - structure_1 = Agent( - config=config, - ) - - assert config.structure == structure_1 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 - - structure_2 = Agent( - config=config, - ) - assert config.structure == structure_2 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..0be2f9758 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from griptape.events import EventBus + + +@pytest.fixture(autouse=True) +def event_bus(): + EventBus.event_listeners = [] + + yield EventBus + + EventBus.event_listeners = [] diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 519e40f57..fc41837fd 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 7447b2c08..96b615a58 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -14,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -30,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -52,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -80,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index 14de15f2d..a77fb268e 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 2708b0a88..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(EventPublisherMixin, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) @@ -42,8 +42,7 @@ def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), Message) def test_run_with_stream(self): - pipeline = Pipeline() - result = MockPromptDriver(stream=True, event_listeners=pipeline.event_listeners).run(PromptStack(messages=[])) + result = MockPromptDriver(stream=True).run(PromptStack(messages=[])) assert isinstance(result, Message) assert result.value == "mock output" diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index 8af5dc827..ab448c7c1 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py new file mode 100644 index 000000000..fd862913e --- /dev/null +++ b/tests/unit/events/test_event_bus.py @@ -0,0 +1,45 @@ +from unittest.mock import Mock + +from griptape.events import EventBus, EventListener +from tests.mocks.mock_event import MockEvent + + +class TestEventBus: + def test_add_event_listeners(self): + EventBus.add_event_listeners([EventListener(), EventListener()]) + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listeners(self): + listeners = [EventListener(), EventListener()] + EventBus.add_event_listeners(listeners) + EventBus.remove_event_listeners(listeners) + assert len(EventBus.event_listeners) == 0 + + def test_add_event_listener(self): + EventBus.add_event_listener(EventListener()) + EventBus.add_event_listener(EventListener()) + + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listener(self): + listener = EventListener() + EventBus.add_event_listener(listener) + EventBus.remove_event_listener(listener) + + assert len(EventBus.event_listeners) == 0 + + def test_remove_unknown_event_listener(self): + EventBus.remove_event_listener(EventListener()) + + def test_publish_event(self): + # Given + mock_handler = Mock() + mock_handler.return_value = None + EventBus.event_listeners = [EventListener(handler=mock_handler)] + mock_event = MockEvent() + + # When + EventBus.publish_event(mock_event) + + # Then + mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index b245c2be9..5601aef34 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,6 +4,7 @@ from griptape.events import ( CompletionChunkEvent, + EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -37,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - pipeline.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -58,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - pipeline.event_listeners = [ + EventBus.event_listeners = [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), EventListener(start_task_event_handler, event_types=[StartTaskEvent]), @@ -86,25 +87,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - pipeline.event_listeners = [] + EventBus.event_listeners = [] mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = pipeline.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = pipeline.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = pipeline.add_event_listener(EventListener(mock2)) + event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) - assert len(pipeline.event_listeners) == 4 + assert len(EventBus.event_listeners) == 4 - pipeline.remove_event_listener(event_listener_1) - pipeline.remove_event_listener(event_listener_3) - pipeline.remove_event_listener(event_listener_4) - pipeline.remove_event_listener(event_listener_5) - assert len(pipeline.event_listeners) == 0 + EventBus.remove_event_listener(event_listener_1) + EventBus.remove_event_listener(event_listener_3) + EventBus.remove_event_listener(event_listener_4) + EventBus.remove_event_listener(event_listener_5) + assert len(EventBus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/mixins/test_events_mixin.py b/tests/unit/mixins/test_events_mixin.py deleted file mode 100644 index 99f5541ba..000000000 --- a/tests/unit/mixins/test_events_mixin.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import Mock - -from griptape.events import EventListener -from griptape.mixins import EventPublisherMixin -from tests.mocks.mock_event import MockEvent - - -class TestEventsMixin: - def test_init(self): - assert EventPublisherMixin() - - def test_add_event_listeners(self): - mixin = EventPublisherMixin() - - mixin.add_event_listeners([EventListener(), EventListener()]) - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listeners(self): - mixin = EventPublisherMixin() - - listeners = [EventListener(), EventListener()] - mixin.add_event_listeners(listeners) - mixin.remove_event_listeners(listeners) - assert len(mixin.event_listeners) == 0 - - def test_add_event_listener(self): - mixin = EventPublisherMixin() - - mixin.add_event_listener(EventListener()) - mixin.add_event_listener(EventListener()) - - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listener(self): - mixin = EventPublisherMixin() - - listener = EventListener() - mixin.add_event_listener(listener) - mixin.remove_event_listener(listener) - - assert len(mixin.event_listeners) == 0 - - def test_remove_unknown_event_listener(self): - mixin = EventPublisherMixin() - - mixin.remove_event_listener(EventListener()) - - def test_publish_event(self): - # Given - mock_handler = Mock() - mock_handler.return_value = None - mixin = EventPublisherMixin(event_listeners=[EventListener(handler=mock_handler)]) - mock_event = MockEvent() - - # When - mixin.publish_event(mock_event) - - # Then - mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 4f4b43d40..636515106 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -15,11 +16,11 @@ class TestBaseTask: @pytest.fixture() def task(self): + EventBus.event_listeners = [EventListener(handler=Mock())] agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], - event_listeners=[EventListener(handler=Mock())], ) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -117,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert task.structure.event_listeners[0].handler.call_count == 2 + assert EventBus.event_listeners[0].handler.call_count == 2 From ae20c82af62747b04ec1053a0674dadb10db8fda Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 5 Aug 2024 17:31:12 -0700 Subject: [PATCH 02/40] WIP --- griptape/config/__init__.py | 2 ++ griptape/config/config.py | 3 ++ ...zon_dynamodb_conversation_memory_driver.py | 5 ++- .../local_conversation_memory_driver.py | 8 +++-- .../redis_conversation_memory_driver.py | 5 ++- .../audio/audio_transcription_engine.py | 7 +++-- .../engines/audio/text_to_speech_engine.py | 8 +++-- .../extraction/base_extraction_engine.py | 3 +- .../image/base_image_generation_engine.py | 8 +++-- .../engines/image_query/image_query_engine.py | 6 ++-- .../response/prompt_response_rag_module.py | 3 +- .../vector_store_retrieval_rag_module.py | 3 +- .../engines/summary/prompt_summary_engine.py | 6 ++-- .../structure/base_conversation_memory.py | 7 +++-- .../structure/summary_conversation_memory.py | 19 ++---------- .../task/storage/text_artifact_storage.py | 5 +-- griptape/structures/structure.py | 31 +++++++------------ griptape/tasks/audio_transcription_task.py | 22 ++----------- griptape/tasks/csv_extraction_task.py | 17 ++-------- griptape/tasks/extraction_task.py | 6 +--- griptape/tasks/image_query_task.py | 17 ++-------- .../tasks/inpainting_image_generation_task.py | 22 ++----------- griptape/tasks/json_extraction_task.py | 17 ++-------- .../outpainting_image_generation_task.py | 23 ++------------ .../tasks/prompt_image_generation_task.py | 22 ++----------- griptape/tasks/prompt_task.py | 12 ++----- griptape/tasks/rag_task.py | 23 ++------------ griptape/tasks/text_summary_task.py | 19 ++---------- griptape/tasks/text_to_speech_task.py | 19 ++---------- .../tasks/variation_image_generation_task.py | 22 ++----------- tests/unit/conftest.py | 16 ++++++++++ tests/unit/events/test_event_listener.py | 5 +-- .../test_summary_conversation_memory.py | 3 +- .../tasks/test_audio_transcription_task.py | 3 +- tests/unit/tasks/test_csv_extraction_task.py | 9 ++---- tests/unit/tasks/test_image_query_task.py | 9 +----- .../test_inpainting_image_generation_task.py | 9 +----- tests/unit/tasks/test_json_extraction_task.py | 13 ++------ .../test_outpainting_image_generation_task.py | 9 +----- .../test_prompt_image_generation_task.py | 11 +------ tests/unit/tasks/test_prompt_task.py | 11 +------ tests/unit/tasks/test_text_summary_task.py | 11 +------ tests/unit/tasks/test_text_to_speech_task.py | 3 +- tests/unit/tasks/test_toolkit_task.py | 9 ++++-- .../test_variation_image_generation_task.py | 9 +----- 45 files changed, 144 insertions(+), 356 deletions(-) create mode 100644 griptape/config/config.py diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 541eb0db0..4b0f8eb28 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -9,6 +9,7 @@ from .anthropic_structure_config import AnthropicStructureConfig from .google_structure_config import GoogleStructureConfig from .cohere_structure_config import CohereStructureConfig +from .config import Config __all__ = [ @@ -21,4 +22,5 @@ "AnthropicStructureConfig", "GoogleStructureConfig", "CohereStructureConfig", + "Config", ] diff --git a/griptape/config/config.py b/griptape/config/config.py new file mode 100644 index 000000000..e3017f8b6 --- /dev/null +++ b/griptape/config/config.py @@ -0,0 +1,3 @@ +from .openai_structure_config import OpenAiStructureConfig + +Config = OpenAiStructureConfig() diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index e52174c28..44f214d7c 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -5,12 +5,13 @@ from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory from griptape.utils import import_optional_dependency if TYPE_CHECKING: import boto3 + from griptape.memory.structure import BaseConversationMemory + @define class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver): @@ -38,6 +39,8 @@ def store(self, memory: BaseConversationMemory) -> None: ) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + response = self.table.get_item(Key=self._get_key()) if "Item" in response and self.value_attribute_key in response["Item"]: diff --git a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py index 8d6399e13..f7b6e7d6e 100644 --- a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py @@ -2,12 +2,14 @@ import os from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory + +if TYPE_CHECKING: + from griptape.memory.structure import BaseConversationMemory @define @@ -18,6 +20,8 @@ def store(self, memory: BaseConversationMemory) -> None: Path(self.file_path).write_text(memory.to_json()) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + if not os.path.exists(self.file_path): return None memory = BaseConversationMemory.from_json(Path(self.file_path).read_text()) diff --git a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py index 2ba3737e8..9afc2f204 100644 --- a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py @@ -6,12 +6,13 @@ from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory from griptape.utils.import_utils import import_optional_dependency if TYPE_CHECKING: from redis import Redis + from griptape.memory.structure import BaseConversationMemory + @define class RedisConversationMemoryDriver(BaseConversationMemoryDriver): @@ -54,6 +55,8 @@ def store(self, memory: BaseConversationMemory) -> None: self.client.hset(self.index, self.conversation_id, memory.to_json()) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + key = self.index memory_json = self.client.hget(key, self.conversation_id) if memory_json: diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index 3631b2d17..a3769842d 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -1,12 +1,15 @@ -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import AudioArtifact, TextArtifact +from griptape.config import Config from griptape.drivers import BaseAudioTranscriptionDriver @define class AudioTranscriptionEngine: - audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True) + audio_transcription_driver: BaseAudioTranscriptionDriver = field( + default=Factory(lambda: Config.audio_transcription_driver), kw_only=True + ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: return self.audio_transcription_driver.try_run(audio) diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index af5d5a494..361ecc127 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import Config if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @@ -11,7 +13,9 @@ @define class TextToSpeechEngine: - text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True) + text_to_speech_driver: BaseTextToSpeechDriver = field( + default=Factory(lambda: Config.text_to_speech_driver), kw_only=True + ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: return self.text_to_speech_driver.try_text_to_audio(prompts=prompts) diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index f263ee0aa..3ff6a96e3 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -6,6 +6,7 @@ from attrs import Attribute, Factory, define, field from griptape.chunkers import BaseChunker, TextChunker +from griptape.config import Config if TYPE_CHECKING: from griptape.artifacts import ErrorArtifact, ListArtifact @@ -17,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 47a853871..eabf38be3 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -3,7 +3,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import Config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @@ -13,7 +15,9 @@ @define class BaseImageGenerationEngine(ABC): - image_generation_driver: BaseImageGenerationDriver = field(kw_only=True) + image_generation_driver: BaseImageGenerationDriver = field( + kw_only=True, default=Factory(lambda: Config.image_generation_driver) + ) @abstractmethod def run(self, prompts: list[str], *args, rulesets: Optional[list[Ruleset]], **kwargs) -> ImageArtifact: ... diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index d0a1e99d4..ed6a64ee3 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import Config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @@ -11,7 +13,7 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(kw_only=True) + image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: Config.image_query_driver), kw_only=True) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 0b7cbd953..2e7b486b6 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -5,6 +5,7 @@ from attrs import Factory, define, field from griptape.artifacts.text_artifact import TextArtifact +from griptape.config import Config from griptape.engines.rag.modules import BaseResponseRagModule from griptape.utils import J2 @@ -16,7 +17,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): answer_token_offset: int = field(default=400) - prompt_driver: BasePromptDriver = field() + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 0a07b4c50..b0deca67d 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -5,6 +5,7 @@ from attrs import Factory, define, field from griptape import utils +from griptape.config import Config from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -17,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field() + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.vector_store_driver)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index c5d8e695d..d06ebaa2f 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -6,8 +6,8 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.message import Message +from griptape.common import Message, PromptStack +from griptape.config import Config from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index c3d3c501e..8794288c8 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -3,9 +3,10 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from attrs import define, field +from attrs import Factory, define, field from griptape.common import PromptStack +from griptape.config import Config from griptape.mixins import SerializableMixin if TYPE_CHECKING: @@ -16,7 +17,9 @@ @define class BaseConversationMemory(SerializableMixin, ABC): - driver: Optional[BaseConversationMemoryDriver] = field(default=None, kw_only=True) + driver: Optional[BaseConversationMemoryDriver] = field( + default=Factory(lambda: Config.conversation_memory_driver), kw_only=True + ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) autoload: bool = field(default=True, kw_only=True) diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index f29bbb767..807775d63 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -5,8 +5,8 @@ from attrs import Factory, define, field -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.message import Message +from griptape.common import Message, PromptStack +from griptape.config import Config from griptape.memory.structure import ConversationMemory from griptape.utils import J2 @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - _prompt_driver: BasePromptDriver = field(kw_only=True, default=None, alias="prompt_driver") + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.prompt_driver)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) @@ -27,19 +27,6 @@ class SummaryConversationMemory(ConversationMemory): kw_only=True, ) - @property - def prompt_driver(self) -> BasePromptDriver: - if self._prompt_driver is None: - if self.structure is not None: - self._prompt_driver = self.structure.config.prompt_driver - else: - raise ValueError("Prompt Driver is not set.") - return self._prompt_driver - - @prompt_driver.setter - def prompt_driver(self, value: BasePromptDriver) -> None: - self._prompt_driver = value - def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: stack = PromptStack() if self.summary: diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 8e66c5aba..8a918c5f2 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING, Any, Optional -from attrs import Attribute, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.config import Config from griptape.engines.rag import RagContext, RagEngine from griptape.memory.task.storage import BaseArtifactStorage @@ -15,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field() + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.vector_store_driver)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index df7113c23..9f1fa9a2b 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -11,7 +11,7 @@ from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable -from griptape.config import BaseStructureConfig, OpenAiStructureConfig, StructureConfig +from griptape.config import BaseStructureConfig, Config from griptape.drivers import ( BaseEmbeddingDriver, BasePromptDriver, @@ -59,10 +59,7 @@ class Structure(ABC): custom_logger: Optional[Logger] = field(default=None, kw_only=True) logger_level: int = field(default=logging.INFO, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( - default=Factory( - lambda self: ConversationMemory(driver=self.config.conversation_memory_driver), - takes_self=True, - ), + default=Factory(lambda: ConversationMemory()), kw_only=True, ) rag_engine: RagEngine = field(default=Factory(lambda self: self.default_rag_engine, takes_self=True), kw_only=True) @@ -154,8 +151,6 @@ def finished_tasks(self) -> list[BaseTask]: @property def default_config(self) -> BaseStructureConfig: if self.prompt_driver is not None or self.embedding_driver is not None or self.stream is not None: - config = StructureConfig() - prompt_driver = OpenAiChatPromptDriver(model="gpt-4o") if self.prompt_driver is None else self.prompt_driver embedding_driver = OpenAiEmbeddingDriver() if self.embedding_driver is None else self.embedding_driver @@ -165,26 +160,24 @@ def default_config(self) -> BaseStructureConfig: vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) - config.prompt_driver = prompt_driver - config.vector_store_driver = vector_store_driver - config.embedding_driver = embedding_driver - else: - config = OpenAiStructureConfig() + Config.prompt_driver = prompt_driver + Config.vector_store_driver = vector_store_driver + Config.embedding_driver = embedding_driver - return config + return Config @property def default_rag_engine(self) -> RagEngine: return RagEngine( retrieval_stage=RetrievalRagStage( - retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=self.config.vector_store_driver)], + retrieval_modules=[VectorStoreRetrievalRagModule()], ), response_stage=ResponseRagStage( before_response_modules=[ RulesetsBeforeResponseRagModule(rulesets=self.rulesets), MetadataBeforeResponseRagModule(), ], - response_module=PromptResponseRagModule(prompt_driver=self.config.prompt_driver), + response_module=PromptResponseRagModule(), ), ) @@ -195,10 +188,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=self.config.vector_store_driver, - summary_engine=PromptSummaryEngine(prompt_driver=self.config.prompt_driver), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=self.config.prompt_driver), - json_extraction_engine=JsonExtractionEngine(prompt_driver=self.config.prompt_driver), + vector_store_driver=Config.vector_store_driver, + summary_engine=PromptSummaryEngine(prompt_driver=Config.prompt_driver), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.prompt_driver), + json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.prompt_driver), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/audio_transcription_task.py b/griptape/tasks/audio_transcription_task.py index 3a4b17b9e..3d83cf7e7 100644 --- a/griptape/tasks/audio_transcription_task.py +++ b/griptape/tasks/audio_transcription_task.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import AudioTranscriptionEngine from griptape.tasks.base_audio_input_task import BaseAudioInputTask @@ -13,26 +13,10 @@ @define class AudioTranscriptionTask(BaseAudioInputTask): - _audio_transcription_engine: AudioTranscriptionEngine = field( - default=None, + audio_transcription_engine: AudioTranscriptionEngine = field( + default=Factory(lambda: AudioTranscriptionEngine()), kw_only=True, - alias="audio_transcription_engine", ) - @property - def audio_transcription_engine(self) -> AudioTranscriptionEngine: - if self._audio_transcription_engine is None: - if self.structure is not None: - self._audio_transcription_engine = AudioTranscriptionEngine( - audio_transcription_driver=self.structure.config.audio_transcription_driver, - ) - else: - raise ValueError("Audio Generation Engine is not set.") - return self._audio_transcription_engine - - @audio_transcription_engine.setter - def audio_transcription_engine(self, value: AudioTranscriptionEngine) -> None: - self._audio_transcription_engine = value - def run(self) -> TextArtifact: return self.audio_transcription_engine.run(self.input) diff --git a/griptape/tasks/csv_extraction_task.py b/griptape/tasks/csv_extraction_task.py index 538596dfe..c252893de 100644 --- a/griptape/tasks/csv_extraction_task.py +++ b/griptape/tasks/csv_extraction_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import CsvExtractionEngine from griptape.tasks import ExtractionTask @@ -8,17 +8,4 @@ @define class CsvExtractionTask(ExtractionTask): - _extraction_engine: CsvExtractionEngine = field(default=None, kw_only=True, alias="extraction_engine") - - @property - def extraction_engine(self) -> CsvExtractionEngine: - if self._extraction_engine is None: - if self.structure is not None: - self._extraction_engine = CsvExtractionEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Extraction Engine is not set.") - return self._extraction_engine - - @extraction_engine.setter - def extraction_engine(self, value: CsvExtractionEngine) -> None: - self._extraction_engine = value + extraction_engine: CsvExtractionEngine = field(default=Factory(lambda: CsvExtractionEngine()), kw_only=True) diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index d8f492693..a1c18eff0 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -13,12 +13,8 @@ @define class ExtractionTask(BaseTextInputTask): - _extraction_engine: BaseExtractionEngine = field(kw_only=True, default=None, alias="extraction_engine") + extraction_engine: BaseExtractionEngine = field(kw_only=True) args: dict = field(kw_only=True) - @property - def extraction_engine(self) -> BaseExtractionEngine: - return self._extraction_engine - def run(self) -> ListArtifact | ErrorArtifact: return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args) diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index ea1b53739..1c77bbc0a 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import ImageQueryEngine @@ -24,7 +24,7 @@ class ImageQueryTask(BaseTask): image_query_engine: The engine used to execute the query. """ - _image_query_engine: ImageQueryEngine = field(default=None, kw_only=True, alias="image_query_engine") + image_query_engine: ImageQueryEngine = field(default=Factory(lambda: ImageQueryEngine()), kw_only=True) _input: ( tuple[str, list[ImageArtifact]] | tuple[TextArtifact, list[ImageArtifact]] @@ -62,19 +62,6 @@ def input( ) -> None: self._input = value - @property - def image_query_engine(self) -> ImageQueryEngine: - if self._image_query_engine is None: - if self.structure is not None: - self._image_query_engine = ImageQueryEngine(image_query_driver=self.structure.config.image_query_driver) - else: - raise ValueError("Image Query Engine is not set.") - return self._image_query_engine - - @image_query_engine.setter - def image_query_engine(self, value: ImageQueryEngine) -> None: - self._image_query_engine = value - def run(self) -> TextArtifact: query = self.input.value[0] diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index 2096c60e4..ec8014672 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import InpaintingImageGenerationEngine @@ -28,10 +28,9 @@ class InpaintingImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: InpaintingImageGenerationEngine = field( - default=None, + image_generation_engine: InpaintingImageGenerationEngine = field( + default=Factory(lambda: InpaintingImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: ( tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact @@ -60,21 +59,6 @@ def input( ) -> None: self._input = value - @property - def image_generation_engine(self) -> InpaintingImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = InpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: InpaintingImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/griptape/tasks/json_extraction_task.py b/griptape/tasks/json_extraction_task.py index ce51b316f..94db187da 100644 --- a/griptape/tasks/json_extraction_task.py +++ b/griptape/tasks/json_extraction_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import JsonExtractionEngine from griptape.tasks import ExtractionTask @@ -8,17 +8,4 @@ @define class JsonExtractionTask(ExtractionTask): - _extraction_engine: JsonExtractionEngine = field(default=None, kw_only=True, alias="extraction_engine") - - @property - def extraction_engine(self) -> JsonExtractionEngine: - if self._extraction_engine is None: - if self.structure is not None: - self._extraction_engine = JsonExtractionEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Extraction Engine is not set.") - return self._extraction_engine - - @extraction_engine.setter - def extraction_engine(self, value: JsonExtractionEngine) -> None: - self._extraction_engine = value + extraction_engine: JsonExtractionEngine = field(default=Factory(lambda: JsonExtractionEngine()), kw_only=True) diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index a23fafd0f..bee3293a1 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import OutpaintingImageGenerationEngine @@ -28,10 +28,9 @@ class OutpaintingImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: OutpaintingImageGenerationEngine = field( - default=None, + image_generation_engine: OutpaintingImageGenerationEngine = field( + default=Factory(lambda: OutpaintingImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: ( tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact @@ -60,22 +59,6 @@ def input( ) -> None: self._input = value - @property - def image_generation_engine(self) -> OutpaintingImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = OutpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: OutpaintingImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index 66cffab3e..efc3faf2d 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.engines import PromptImageGenerationEngine @@ -30,10 +30,9 @@ class PromptImageGenerationTask(BaseImageGenerationTask): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(default=DEFAULT_INPUT_TEMPLATE) - _image_generation_engine: PromptImageGenerationEngine = field( - default=None, + image_generation_engine: PromptImageGenerationEngine = field( + default=Factory(lambda: PromptImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) @property @@ -49,21 +48,6 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - @property - def image_generation_engine(self) -> PromptImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = PromptImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: PromptImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: image_artifact = self.image_generation_engine.run( prompts=[self.input.to_text()], diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 386ebe239..9f698787f 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -6,6 +6,7 @@ from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack +from griptape.config import Config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -16,7 +17,7 @@ @define class PromptTask(RuleMixin, BaseTask): - _prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True, alias="prompt_driver") + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, @@ -56,15 +57,6 @@ def prompt_stack(self) -> PromptStack: return stack - @property - def prompt_driver(self) -> BasePromptDriver: - if self._prompt_driver is None: - if self.structure is not None: - self._prompt_driver = self.structure.config.prompt_driver - else: - raise ValueError("Prompt Driver is not set") - return self._prompt_driver - def default_system_template_generator(self, _: PromptTask) -> str: return J2("tasks/prompt_task/system.j2").render( rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), diff --git a/griptape/tasks/rag_task.py b/griptape/tasks/rag_task.py index 3f88f34d1..97b295209 100644 --- a/griptape/tasks/rag_task.py +++ b/griptape/tasks/rag_task.py @@ -1,32 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.engines.rag import RagEngine from griptape.tasks import BaseTextInputTask -if TYPE_CHECKING: - from griptape.engines.rag import RagEngine - @define class RagTask(BaseTextInputTask): - _rag_engine: RagEngine = field(kw_only=True, default=None, alias="rag_engine") - - @property - def rag_engine(self) -> RagEngine: - if self._rag_engine is None: - if self.structure is not None: - self._rag_engine = self.structure.rag_engine - else: - raise ValueError("rag_engine is not set.") - return self._rag_engine - - @rag_engine.setter - def rag_engine(self, value: RagEngine) -> None: - self._rag_engine = value + rag_engine: RagEngine = field(kw_only=True, default=Factory(lambda: RagEngine())) def run(self) -> BaseArtifact: result = self.rag_engine.process_query(self.input.to_text()).output diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index 5bd1b547e..dc1a7b8be 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.engines import PromptSummaryEngine @@ -14,20 +14,7 @@ @define class TextSummaryTask(BaseTextInputTask): - _summary_engine: Optional[BaseSummaryEngine] = field(default=None, alias="summary_engine") - - @property - def summary_engine(self) -> Optional[BaseSummaryEngine]: - if self._summary_engine is None: - if self.structure is not None: - self._summary_engine = PromptSummaryEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Summary Engine is not set.") - return self._summary_engine - - @summary_engine.setter - def summary_engine(self, value: BaseSummaryEngine) -> None: - self._summary_engine = value + summary_engine: BaseSummaryEngine = field(default=Factory(lambda: PromptSummaryEngine()), kw_only=True) def run(self) -> TextArtifact: return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.all_rulesets)) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index 3ca503dfe..680a67603 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.engines import TextToSpeechEngine @@ -19,7 +19,7 @@ class TextToSpeechTask(BaseAudioGenerationTask): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(default=DEFAULT_INPUT_TEMPLATE) - _text_to_speech_engine: TextToSpeechEngine = field(default=None, kw_only=True, alias="text_to_speech_engine") + text_to_speech_engine: TextToSpeechEngine = field(default=Factory(lambda: TextToSpeechEngine()), kw_only=True) @property def input(self) -> TextArtifact: @@ -34,21 +34,6 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - @property - def text_to_speech_engine(self) -> TextToSpeechEngine: - if self._text_to_speech_engine is None: - if self.structure is not None: - self._text_to_speech_engine = TextToSpeechEngine( - text_to_speech_driver=self.structure.config.text_to_speech_driver, - ) - else: - raise ValueError("Audio Generation Engine is not set.") - return self._text_to_speech_engine - - @text_to_speech_engine.setter - def text_to_speech_engine(self, value: TextToSpeechEngine) -> None: - self._text_to_speech_engine = value - def run(self) -> AudioArtifact: audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets) diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index df4579efa..6295b1af6 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import VariationImageGenerationEngine @@ -28,10 +28,9 @@ class VariationImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: VariationImageGenerationEngine = field( - default=None, + image_generation_engine: VariationImageGenerationEngine = field( + default=Factory(lambda: VariationImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact = field( default=None, @@ -57,21 +56,6 @@ def input(self) -> ListArtifact: def input(self, value: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact]) -> None: self._input = value - @property - def image_generation_engine(self) -> VariationImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = VariationImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: VariationImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0be2f9758..7d2f8203d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,6 +1,8 @@ import pytest +from griptape.config import Config from griptape.events import EventBus +from tests.mocks.mock_structure_config import MockStructureConfig @pytest.fixture(autouse=True) @@ -10,3 +12,17 @@ def event_bus(): yield EventBus EventBus.event_listeners = [] + + +@pytest.fixture(autouse=True) +def mock_config(): + mock_structure_config = MockStructureConfig() + Config.prompt_driver = mock_structure_config.prompt_driver + Config.image_generation_driver = mock_structure_config.image_generation_driver + Config.image_query_driver = mock_structure_config.image_query_driver + Config.embedding_driver = mock_structure_config.embedding_driver + Config.vector_store_driver = mock_structure_config.vector_store_driver + Config.text_to_speech_driver = mock_structure_config.text_to_speech_driver + Config.audio_transcription_driver = mock_structure_config.audio_transcription_driver + + return Config diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 5601aef34..92c5a3653 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -25,8 +25,9 @@ class TestEventListener: @pytest.fixture() - def pipeline(self): + def pipeline(self, mock_config): task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) + mock_config.prompt_driver = MockPromptDriver(stream=True) pipeline = Pipeline(prompt_driver=MockPromptDriver(stream=True)) pipeline.add_task(task) @@ -34,7 +35,7 @@ def pipeline(self): task.add_subtask(ActionsSubtask("foo")) return pipeline - def test_untyped_listeners(self, pipeline): + def test_untyped_listeners(self, pipeline, mock_config): event_handler_1 = Mock() event_handler_2 = Mock() diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index 4396c7b23..579be214e 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -5,7 +5,6 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestSummaryConversationMemory: @@ -85,7 +84,7 @@ def test_from_json(self): def test_config_prompt_driver(self): memory = SummaryConversationMemory() - pipeline = Pipeline(conversation_memory=memory, config=MockStructureConfig()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index 734e111cf..4cc860c32 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -7,7 +7,6 @@ from griptape.structures import Agent, Pipeline from griptape.tasks import AudioTranscriptionTask, BaseTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestAudioTranscriptionTask: @@ -34,7 +33,7 @@ def callable_input(task: BaseTask) -> AudioArtifact: def test_config_audio_transcription_engine(self, audio_artifact): task = AudioTranscriptionTask(audio_artifact) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.audio_transcription_engine, AudioTranscriptionEngine) diff --git a/tests/unit/tasks/test_csv_extraction_task.py b/tests/unit/tasks/test_csv_extraction_task.py index 7d37c3897..ec8f70b23 100644 --- a/tests/unit/tasks/test_csv_extraction_task.py +++ b/tests/unit/tasks/test_csv_extraction_task.py @@ -4,7 +4,6 @@ from griptape.structures import Agent from griptape.tasks import CsvExtractionTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestCsvExtractionTask: @@ -13,7 +12,7 @@ def task(self): return CsvExtractionTask(args={"column_names": ["test1"]}) def test_run(self, task): - agent = Agent(config=MockStructureConfig()) + agent = Agent() agent.add_task(task) @@ -23,11 +22,7 @@ def test_run(self, task): assert result.value[0].value == {"test1": "mock output"} def test_config_extraction_engine(self, task): - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.extraction_engine, CsvExtractionEngine) assert isinstance(task.extraction_engine.prompt_driver, MockPromptDriver) - - def test_missing_extraction_engine(self, task): - with pytest.raises(ValueError): - task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py index 447faa01c..01c116772 100644 --- a/tests/unit/tasks/test_image_query_task.py +++ b/tests/unit/tasks/test_image_query_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, ImageQueryTask from tests.mocks.mock_image_query_driver import MockImageQueryDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestImageQueryTask: @@ -61,17 +60,11 @@ def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArti def test_config_image_generation_engine(self, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_query_engine, ImageQueryEngine) assert isinstance(task.image_query_engine.image_query_driver, MockImageQueryDriver) - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - - with pytest.raises(ValueError, match="Image Query Engine"): - task.image_query_engine # noqa: B018 - def test_run(self, image_query_engine, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact]), image_query_engine=image_query_engine) task.run() diff --git a/tests/unit/tasks/test_inpainting_image_generation_task.py b/tests/unit/tasks/test_inpainting_image_generation_task.py index 61c437bb7..5c4507d49 100644 --- a/tests/unit/tasks/test_inpainting_image_generation_task.py +++ b/tests/unit/tasks/test_inpainting_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, InpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestInpaintingImageGenerationTask: @@ -51,13 +50,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, InpaintingImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index ba7d1ce30..0189e6679 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -5,7 +5,6 @@ from griptape.structures import Agent from griptape.tasks import JsonExtractionTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestJsonExtractionTask: @@ -13,11 +12,9 @@ class TestJsonExtractionTask: def task(self): return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) - def test_run(self, task): - mock_config = MockStructureConfig() - assert isinstance(mock_config.prompt_driver, MockPromptDriver) + def test_run(self, task, mock_config): mock_config.prompt_driver.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - agent = Agent(config=mock_config) + agent = Agent() agent.add_task(task) @@ -28,11 +25,7 @@ def test_run(self, task): assert result.value[1].value == '{"test_key_2": "test_value_2"}' def test_config_extraction_engine(self, task): - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.extraction_engine, JsonExtractionEngine) assert isinstance(task.extraction_engine.prompt_driver, MockPromptDriver) - - def test_missing_extraction_engine(self, task): - with pytest.raises(ValueError): - task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_outpainting_image_generation_task.py b/tests/unit/tasks/test_outpainting_image_generation_task.py index 593451120..ba5e52a82 100644 --- a/tests/unit/tasks/test_outpainting_image_generation_task.py +++ b/tests/unit/tasks/test_outpainting_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, OutpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestOutpaintingImageGenerationTask: @@ -51,13 +50,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, OutpaintingImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_image_generation_task.py b/tests/unit/tasks/test_prompt_image_generation_task.py index 1c4b639fb..3ad0302f2 100644 --- a/tests/unit/tasks/test_prompt_image_generation_task.py +++ b/tests/unit/tasks/test_prompt_image_generation_task.py @@ -1,13 +1,10 @@ from unittest.mock import Mock -import pytest - from griptape.artifacts import TextArtifact from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent from griptape.tasks import BaseTask, PromptImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestPromptImageGenerationTask: @@ -28,13 +25,7 @@ def callable_input(task: BaseTask) -> TextArtifact: def test_config_image_generation_engine_engine(self): task = PromptImageGenerationTask("foo bar") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, PromptImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_summary_engine(self): - task = PromptImageGenerationTask("foo bar") - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 083ea6da5..4a618e0d1 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,5 +1,3 @@ -import pytest - from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact @@ -9,7 +7,6 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestPromptTask: @@ -30,16 +27,10 @@ def test_to_text(self): def test_config_prompt_driver(self): task = PromptTask("test") - Pipeline(config=MockStructureConfig()).add_task(task) + Pipeline().add_task(task) assert isinstance(task.prompt_driver, MockPromptDriver) - def test_missing_prompt_driver(self): - task = PromptTask("test") - - with pytest.raises(ValueError): - task.prompt_driver # noqa: B018 - def test_input(self): # Str task = PromptTask("test") diff --git a/tests/unit/tasks/test_text_summary_task.py b/tests/unit/tasks/test_text_summary_task.py index bb08f9d31..438d2bae4 100644 --- a/tests/unit/tasks/test_text_summary_task.py +++ b/tests/unit/tasks/test_text_summary_task.py @@ -1,10 +1,7 @@ -import pytest - from griptape.engines import PromptSummaryEngine from griptape.structures import Agent from griptape.tasks import TextSummaryTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestTextSummaryTask: @@ -26,13 +23,7 @@ def test_context_propagation(self): def test_config_summary_engine(self): task = TextSummaryTask("test") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.summary_engine, PromptSummaryEngine) assert isinstance(task.summary_engine.prompt_driver, MockPromptDriver) - - def test_missing_summary_engine(self): - task = TextSummaryTask("test") - - with pytest.raises(ValueError): - task.summary_engine # noqa: B018 diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index bf1f19d5a..3c629c69d 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -5,7 +5,6 @@ from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestTextToSpeechTask: @@ -26,7 +25,7 @@ def callable_input(task: BaseTask) -> TextArtifact: def test_config_text_to_speech_engine(self): task = TextToSpeechTask("foo bar") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.text_to_speech_engine, TextToSpeechEngine) diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index cd5dd21f8..a47e4687b 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -170,8 +170,9 @@ def test_init(self): except ValueError: assert True - def test_run(self): + def test_run(self, mock_config): output = """Answer: done""" + mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) @@ -184,8 +185,9 @@ def test_run(self): assert len(task.subtasks) == 1 assert result.output_task.output.to_text() == "done" - def test_run_max_subtasks(self): + def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' + mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) @@ -197,8 +199,9 @@ def test_run_max_subtasks(self): assert len(task.subtasks) == 3 assert isinstance(task.output, ErrorArtifact) - def test_run_invalid_react_prompt(self): + def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" + mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) diff --git a/tests/unit/tasks/test_variation_image_generation_task.py b/tests/unit/tasks/test_variation_image_generation_task.py index a910fb8e0..f6afbf03e 100644 --- a/tests/unit/tasks/test_variation_image_generation_task.py +++ b/tests/unit/tasks/test_variation_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, VariationImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestVariationImageGenerationTask: @@ -48,13 +47,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = VariationImageGenerationTask((text_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, VariationImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_summary_engine(self, text_artifact, image_artifact): - task = VariationImageGenerationTask((text_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 From 44f9ebd3b3f9195bf4d850b2fa551b289708dc6d Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 11:20:36 -0700 Subject: [PATCH 03/40] WIP Event listners --- griptape/config/base_structure_config.py | 2 +- griptape/config/structure_config.py | 22 ++++---- griptape/drivers/prompt/base_prompt_driver.py | 2 + .../structure/base_conversation_memory.py | 2 +- griptape/structures/agent.py | 1 + griptape/structures/structure.py | 50 +------------------ griptape/utils/stream.py | 6 ++- .../test_azure_openai_structure_config.py | 15 +----- tests/unit/config/test_structure_config.py | 25 +--------- tests/unit/utils/test_stream.py | 7 +-- 10 files changed, 30 insertions(+), 102 deletions(-) diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index c2aa82d7e..84743c4da 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -22,7 +22,7 @@ @define -class BaseStructureConfig(BaseConfig, ABC): +class BaseStructureConfig(BaseConfig, ABC, EventPublisherMixin): prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/structure_config.py b/griptape/config/structure_config.py index ef95012ce..d68b6e2e2 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/structure_config.py @@ -1,19 +1,11 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional from attrs import Factory, define, field from griptape.config import BaseStructureConfig from griptape.drivers import ( - BaseAudioTranscriptionDriver, - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BaseImageQueryDriver, - BasePromptDriver, - BaseTextToSpeechDriver, - BaseVectorStoreDriver, DummyAudioTranscriptionDriver, DummyEmbeddingDriver, DummyImageGenerationDriver, @@ -23,6 +15,18 @@ DummyVectorStoreDriver, ) +if TYPE_CHECKING: + from griptape.drivers import ( + BaseAudioTranscriptionDriver, + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseTextToSpeechDriver, + BaseVectorStoreDriver, + ) + @define class StructureConfig(BaseStructureConfig): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..b6c28560b 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -113,6 +113,8 @@ def __process_run(self, prompt_stack: PromptStack) -> Message: return result def __process_stream(self, prompt_stack: PromptStack) -> Message: + from griptape.config import Config + delta_contents: dict[int, list[BaseDeltaMessageContent]] = {} usage = DeltaMessage.Usage() diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 8794288c8..fb1cfdd8b 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = self.structure.config.prompt_driver + prompt_driver = Config.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index b133a7b6b..31e0a424f 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -32,6 +32,7 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB def __attrs_post_init__(self) -> None: super().__attrs_post_init__() + if len(self.tasks) == 0: if self.tools: task = ToolkitTask(self.input, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 9f1fa9a2b..73b5e617a 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -11,14 +11,7 @@ from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable -from griptape.config import BaseStructureConfig, Config -from griptape.drivers import ( - BaseEmbeddingDriver, - BasePromptDriver, - LocalVectorStoreDriver, - OpenAiChatPromptDriver, - OpenAiEmbeddingDriver, -) +from griptape.config import Config from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( @@ -33,7 +26,6 @@ from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.utils import deprecation_warn if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory @@ -46,13 +38,6 @@ class Structure(ABC): LOGGER_NAME = "griptape" id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) - stream: Optional[bool] = field(default=None, kw_only=True) - prompt_driver: Optional[BasePromptDriver] = field(default=None) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - config: BaseStructureConfig = field( - default=Factory(lambda self: self.default_config, takes_self=True), - kw_only=True, - ) rulesets: list[Ruleset] = field(factory=list, kw_only=True) rules: list[Rule] = field(factory=list, kw_only=True) tasks: list[BaseTask] = field(factory=list, kw_only=True) @@ -99,21 +84,6 @@ def __attrs_post_init__(self) -> None: def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]: return self.add_tasks(*other) if isinstance(other, list) else self + [other] - @prompt_driver.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_prompt_driver(self, attribute: Attribute, value: BasePromptDriver) -> None: - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver` instead.") - - @embedding_driver.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_embedding_driver(self, attribute: Attribute, value: BaseEmbeddingDriver) -> None: - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.embedding_driver` instead.") - - @stream.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_stream(self, attribute: Attribute, value: bool) -> None: # noqa: FBT001 - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver.stream` instead.") - @property def execution_args(self) -> tuple: return self._execution_args @@ -148,24 +118,6 @@ def output(self) -> Optional[BaseArtifact]: def finished_tasks(self) -> list[BaseTask]: return [s for s in self.tasks if s.is_finished()] - @property - def default_config(self) -> BaseStructureConfig: - if self.prompt_driver is not None or self.embedding_driver is not None or self.stream is not None: - prompt_driver = OpenAiChatPromptDriver(model="gpt-4o") if self.prompt_driver is None else self.prompt_driver - - embedding_driver = OpenAiEmbeddingDriver() if self.embedding_driver is None else self.embedding_driver - - if self.stream is not None: - prompt_driver.stream = self.stream - - vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) - - Config.prompt_driver = prompt_driver - Config.vector_store_driver = vector_store_driver - Config.embedding_driver = embedding_driver - - return Config - @property def default_rag_engine(self) -> RagEngine: return RagEngine( diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 4a7899b2a..7b5381202 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -34,7 +34,9 @@ class Stream: @structure.validator # pyright: ignore[reportAttributeAccessIssue] def validate_structure(self, _: Attribute, structure: Structure) -> None: - if not structure.config.prompt_driver.stream: + from griptape.config import Config + + if not Config.prompt_driver.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) @@ -54,6 +56,8 @@ def run(self, *args) -> Iterator[TextArtifact]: t.join() def _run_structure(self, *args) -> None: + from griptape.config import Config + def event_handler(event: BaseEvent) -> None: self._event_queue.put(event) diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index dcdc3a1dc..abeb6b878 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -8,7 +8,7 @@ class TestAzureOpenAiStructureConfig: def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") - @pytest.fixture() + @pytest.fixture def config(self): return AzureOpenAiStructureConfig( azure_endpoint="http://localhost:8080", @@ -85,16 +85,3 @@ def test_to_dict(self, config): "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } - - def test_from_dict(self, config: AzureOpenAiStructureConfig): - assert AzureOpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - - # override values in the dict config - # serialize and deserialize the config - new_config = config.merge_config( - { - "prompt_driver": {"azure_deployment": "new-test-gpt-4"}, - "embedding_driver": {"model": "new-text-embedding-3-small"}, - } - ).to_dict() - assert AzureOpenAiStructureConfig.from_dict(new_config).to_dict() == new_config diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 96a68628f..c7f02034f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -4,7 +4,7 @@ class TestStructureConfig: - @pytest.fixture() + @pytest.fixture def config(self): return StructureConfig() @@ -33,29 +33,6 @@ def test_to_dict(self, config): def test_from_dict(self, config): assert StructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - def test_unchanged_merge_config(self, config): - assert ( - config.merge_config( - { - "type": "StructureConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - } - ).to_dict() - == config.to_dict() - ) - - def test_changed_merge_config(self, config): - config = config.merge_config( - {"prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}} - ) - - assert config.prompt_driver.temperature == 0.1 - def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index da6695139..555daa4fd 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -2,18 +2,19 @@ import pytest +from griptape.config import Config from griptape.structures import Agent from griptape.utils import Stream -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - return Agent(prompt_driver=MockPromptDriver(stream=request.param, max_attempts=0)) + Config.prompt_driver.stream = request.param + return Agent() def test_init(self, agent): - if agent.prompt_driver.stream: + if Config.prompt_driver.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent From 6b76237978b756650e7c1ab5ce5c7cfc97968f85 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:12:25 -0700 Subject: [PATCH 04/40] Fix tests --- griptape/utils/chat.py | 8 +- tests/mocks/mock_structure_config.py | 5 + .../test_azure_openai_structure_config.py | 2 +- tests/unit/config/test_structure_config.py | 2 +- .../test_base_audio_transcription_driver.py | 2 +- ...est_dynamodb_conversation_memory_driver.py | 13 +-- .../test_local_conversation_memory_driver.py | 10 +- ...est_open_telemetry_observability_driver.py | 3 +- .../drivers/prompt/test_base_prompt_driver.py | 24 ++-- .../test_local_structure_run_driver.py | 7 +- .../extraction/test_csv_extraction_engine.py | 3 +- ...est_footnote_prompt_response_rag_module.py | 3 +- .../test_prompt_response_rag_module.py | 3 +- tests/unit/engines/rag/test_rag_engine.py | 23 +--- .../summary/test_prompt_summary_engine.py | 6 +- tests/unit/events/test_event_listener.py | 6 +- .../test_finish_actions_subtask_event.py | 3 +- tests/unit/events/test_finish_task_event.py | 3 +- .../test_start_actions_subtask_event.py | 3 +- tests/unit/events/test_start_task_event.py | 3 +- .../structure/test_conversation_memory.py | 16 ++- .../test_summary_conversation_memory.py | 8 +- tests/unit/structures/test_agent.py | 68 ++---------- tests/unit/structures/test_pipeline.py | 63 +++-------- tests/unit/structures/test_workflow.py | 104 ++++++------------ .../tasks/test_audio_transcription_task.py | 3 +- .../tasks/test_base_multi_text_input_task.py | 3 +- tests/unit/tasks/test_base_task.py | 5 +- tests/unit/tasks/test_base_text_input_task.py | 3 +- tests/unit/tasks/test_code_execution_task.py | 3 +- tests/unit/tasks/test_extraction_task.py | 5 +- tests/unit/tasks/test_prompt_task.py | 2 +- tests/unit/tasks/test_rag_task.py | 7 +- tests/unit/tasks/test_structure_run_task.py | 8 +- tests/unit/tasks/test_text_summary_task.py | 2 +- tests/unit/tasks/test_text_to_speech_task.py | 3 +- tests/unit/tasks/test_tool_task.py | 10 +- tests/unit/tasks/test_toolkit_task.py | 7 +- tests/unit/tools/test_structure_run_client.py | 4 +- tests/unit/utils/test_chat.py | 3 +- tests/unit/utils/test_conversation.py | 10 +- tests/unit/utils/test_file_utils.py | 5 +- tests/unit/utils/test_structure_visualizer.py | 5 +- tests/utils/defaults.py | 6 +- tests/utils/test_reference_utils.py | 3 +- 45 files changed, 164 insertions(+), 324 deletions(-) diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index e98eeaa4d..a8bdc9b13 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,12 +25,16 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - if self.structure.config.prompt_driver.stream: + from griptape.config import Config + + if Config.prompt_driver.stream: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 def start(self) -> None: + from griptape.config import Config + if self.intro_text: self.output_fn(self.intro_text) while True: @@ -40,7 +44,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if self.structure.config.prompt_driver.stream: + if Config.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_structure_config.py index 3f95288f4..0b374449d 100644 --- a/tests/mocks/mock_structure_config.py +++ b/tests/mocks/mock_structure_config.py @@ -1,6 +1,7 @@ from attrs import Factory, define, field from griptape.config import StructureConfig +from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -21,3 +22,7 @@ class MockStructureConfig(StructureConfig): embedding_driver: MockEmbeddingDriver = field( default=Factory(lambda: MockEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} ) + vector_store_driver: LocalVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + metadata={"serializable": True}, + ) diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index abeb6b878..810cb41a1 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -8,7 +8,7 @@ class TestAzureOpenAiStructureConfig: def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") - @pytest.fixture + @pytest.fixture() def config(self): return AzureOpenAiStructureConfig( azure_endpoint="http://localhost:8080", diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index c7f02034f..cce97647e 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -4,7 +4,7 @@ class TestStructureConfig: - @pytest.fixture + @pytest.fixture() def config(self): return StructureConfig() diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index fc41837fd..29aecfdf9 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -12,7 +12,7 @@ class TestBaseAudioTranscriptionDriver: def driver(self): return MockAudioTranscriptionDriver() - def test_run_publish_events(self, driver): + def test_run_publish_events(self, driver, mock_config): mock_handler = Mock() EventBus.add_event_listener(EventListener(handler=mock_handler)) diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index 8e700d0a5..f1a5df1be 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -6,7 +6,6 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.aws import mock_aws_credentials @@ -40,7 +39,6 @@ def test_store(self): session = boto3.Session(region_name=self.AWS_REGION) dynamodb = session.resource("dynamodb") table = dynamodb.Table(self.DYNAMODB_TABLE_NAME) - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=session, table_name=self.DYNAMODB_TABLE_NAME, @@ -49,7 +47,7 @@ def test_store(self): partition_key_value=self.PARTITION_KEY_VALUE, ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -65,7 +63,6 @@ def test_store_with_sort_key(self): session = boto3.Session(region_name=self.AWS_REGION) dynamodb = session.resource("dynamodb") table = dynamodb.Table(self.DYNAMODB_TABLE_NAME) - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=session, table_name=self.DYNAMODB_TABLE_NAME, @@ -76,7 +73,7 @@ def test_store_with_sort_key(self): sort_key_value="foo", ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -89,7 +86,6 @@ def test_store_with_sort_key(self): assert "Item" in response def test_load(self): - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=boto3.Session(region_name=self.AWS_REGION), table_name=self.DYNAMODB_TABLE_NAME, @@ -98,7 +94,7 @@ def test_load(self): partition_key_value=self.PARTITION_KEY_VALUE, ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -113,7 +109,6 @@ def test_load(self): assert new_memory.runs[0].output.value == "mock output" def test_load_with_sort_key(self): - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=boto3.Session(region_name=self.AWS_REGION), table_name=self.DYNAMODB_TABLE_NAME, @@ -124,7 +119,7 @@ def test_load_with_sort_key(self): sort_key_value="foo", ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index e1a383ab9..dff66d0fc 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -7,7 +7,6 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestLocalConversationMemoryDriver: @@ -22,10 +21,9 @@ def _run_before_and_after_tests(self): self.__delete_file(self.MEMORY_FILE_PATH) def test_store(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver, autoload=False) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -41,10 +39,9 @@ def test_store(self): assert True def test_load(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver, autoload=False, max_runs=5) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -60,10 +57,9 @@ def test_load(self): assert new_memory.max_runs == 5 def test_autoload(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py b/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py index 4f7ce50f0..758505b26 100644 --- a/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py +++ b/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py @@ -8,7 +8,6 @@ from griptape.drivers import OpenTelemetryObservabilityDriver from griptape.observability.observability import Observability from griptape.structures.agent import Agent -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.expected_spans import ExpectedSpan, ExpectedSpans @@ -170,7 +169,7 @@ def test_observability_agent(self, driver, mock_span_exporter): ) with Observability(observability_driver=driver): - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.run("Hi") assert mock_span_exporter.export.call_count == 1 diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b6b0c600..d95e7a5a7 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -10,17 +10,17 @@ class TestBasePromptDriver: - def test_run_via_pipeline_retries_success(self): - driver = MockPromptDriver(max_attempts=1) - pipeline = Pipeline(prompt_driver=driver) + def test_run_via_pipeline_retries_success(self, mock_config): + mock_config.prompt_driver = MockPromptDriver(max_attempts=2) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) assert isinstance(pipeline.run().output_task.output, TextArtifact) - def test_run_via_pipeline_retries_failure(self): - driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) - pipeline = Pipeline(prompt_driver=driver) + def test_run_via_pipeline_retries_failure(self, mock_config): + mock_config.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -46,9 +46,9 @@ def test_run_with_stream(self): assert isinstance(result, Message) assert result.value == "mock output" - def test_run_with_tools(self): - driver = MockPromptDriver(max_attempts=1, use_native_tools=True) - pipeline = Pipeline(prompt_driver=driver) + def test_run_with_tools(self, mock_config): + mock_config.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) + pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) @@ -56,9 +56,9 @@ def test_run_with_tools(self): assert isinstance(output, TextArtifact) assert output.value == "mock output" - def test_run_with_tools_and_stream(self): - driver = MockPromptDriver(max_attempts=1, stream=True, use_native_tools=True) - pipeline = Pipeline(prompt_driver=driver) + def test_run_with_tools_and_stream(self, mock_config): + mock_config.driver = MockPromptDriver(max_attempts=1, stream=True, use_native_tools=True) + pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 316f7bf71..318a41aa2 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -9,7 +9,7 @@ class TestLocalStructureRunDriver: def test_run(self): pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent(prompt_driver=MockPromptDriver())) + driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent()) task = StructureRunTask(driver=driver) @@ -17,10 +17,11 @@ def test_run(self): assert task.run().to_text() == "mock output" - def test_run_with_env(self): + def test_run_with_env(self, mock_config): pipeline = Pipeline() - agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda _: os.environ["KEY"])) + mock_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index f69d8a0ba..d84fc7cdd 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -1,13 +1,12 @@ import pytest from griptape.engines import CsvExtractionEngine -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestCsvExtractionEngine: @pytest.fixture() def engine(self): - return CsvExtractionEngine(prompt_driver=MockPromptDriver()) + return CsvExtractionEngine() def test_extract(self, engine): result = engine.extract("foo", column_names=["test1"]) diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index 385cf0c04..f7819c6d7 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -4,13 +4,12 @@ from griptape.common import Reference from griptape.engines.rag import RagContext from griptape.engines.rag.modules import FootnotePromptResponseRagModule -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestFootnotePromptResponseRagModule: @pytest.fixture() def module(self): - return FootnotePromptResponseRagModule(prompt_driver=MockPromptDriver()) + return FootnotePromptResponseRagModule() def test_run(self, module): assert module.run(RagContext(query="test")).output.value == "mock output" diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index 2f8a912e2..0e3526a52 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -3,13 +3,12 @@ from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import PromptResponseRagModule -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestPromptResponseRagModule: @pytest.fixture() def module(self): - return PromptResponseRagModule(prompt_driver=MockPromptDriver()) + return PromptResponseRagModule() def test_run(self, module): assert module.run(RagContext(query="test")).output.value == "mock output" diff --git a/tests/unit/engines/rag/test_rag_engine.py b/tests/unit/engines/rag/test_rag_engine.py index c3d728bb3..40ab4af4d 100644 --- a/tests/unit/engines/rag/test_rag_engine.py +++ b/tests/unit/engines/rag/test_rag_engine.py @@ -1,36 +1,25 @@ import pytest -from griptape.drivers import LocalVectorStoreDriver from griptape.engines.rag import RagContext, RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestRagEngine: @pytest.fixture() def engine(self): return RagEngine( - retrieval_stage=RetrievalRagStage( - retrieval_modules=[ - VectorStoreRetrievalRagModule( - vector_store_driver=LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - ) - ] - ), - response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver())), + retrieval_stage=RetrievalRagStage(retrieval_modules=[VectorStoreRetrievalRagModule()]), + response_stage=ResponseRagStage(response_module=PromptResponseRagModule()), ) def test_module_name_uniqueness(self): - vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - with pytest.raises(ValueError): RagEngine( retrieval_stage=RetrievalRagStage( retrieval_modules=[ - VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver), - VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver), + VectorStoreRetrievalRagModule(name="test"), + VectorStoreRetrievalRagModule(name="test"), ] ) ) @@ -38,8 +27,8 @@ def test_module_name_uniqueness(self): assert RagEngine( retrieval_stage=RetrievalRagStage( retrieval_modules=[ - VectorStoreRetrievalRagModule(name="test1", vector_store_driver=vector_store_driver), - VectorStoreRetrievalRagModule(name="test2", vector_store_driver=vector_store_driver), + VectorStoreRetrievalRagModule(name="test1"), + VectorStoreRetrievalRagModule(name="test2"), ] ) ) diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 4d9c65e03..138444ae3 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -12,7 +12,7 @@ class TestPromptSummaryEngine: @pytest.fixture() def engine(self): - return PromptSummaryEngine(prompt_driver=MockPromptDriver()) + return PromptSummaryEngine() def test_summarize_text(self, engine): assert engine.summarize_text("foobar") == "mock output" @@ -24,10 +24,10 @@ def test_summarize_artifacts(self, engine): def test_max_token_multiplier_invalid(self, engine): with pytest.raises(ValueError): - PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=0) + PromptSummaryEngine(max_token_multiplier=0) with pytest.raises(ValueError): - PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=10000) + PromptSummaryEngine(max_token_multiplier=10000) def test_chunked_summary(self, engine): def smaller_input(prompt_stack: PromptStack): diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 92c5a3653..ed978db78 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -26,10 +26,10 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) mock_config.prompt_driver = MockPromptDriver(stream=True) + task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) - pipeline = Pipeline(prompt_driver=MockPromptDriver(stream=True)) + pipeline = Pipeline() pipeline.add_task(task) task.add_subtask(ActionsSubtask("foo")) @@ -49,7 +49,7 @@ def test_untyped_listeners(self, pipeline, mock_config): assert event_handler_1.call_count == 9 assert event_handler_2.call_count == 9 - def test_typed_listeners(self, pipeline): + def test_typed_listeners(self, pipeline, mock_config): start_prompt_event_handler = Mock() finish_prompt_event_handler = Mock() start_task_event_handler = Mock() diff --git a/tests/unit/events/test_finish_actions_subtask_event.py b/tests/unit/events/test_finish_actions_subtask_event.py index 5e2a0807a..5fc35755b 100644 --- a/tests/unit/events/test_finish_actions_subtask_event.py +++ b/tests/unit/events/test_finish_actions_subtask_event.py @@ -3,7 +3,6 @@ from griptape.events import FinishActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -17,7 +16,7 @@ def finish_subtask_event(self): "Answer: test output" ) task = ToolkitTask(tools=[MockTool()]) - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) subtask = ActionsSubtask(valid_input) task.add_subtask(subtask) diff --git a/tests/unit/events/test_finish_task_event.py b/tests/unit/events/test_finish_task_event.py index df1d6d42a..2568752bb 100644 --- a/tests/unit/events/test_finish_task_event.py +++ b/tests/unit/events/test_finish_task_event.py @@ -3,14 +3,13 @@ from griptape.events import FinishTaskEvent from griptape.structures import Agent from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestFinishTaskEvent: @pytest.fixture() def finish_task_event(self): task = PromptTask() - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) agent.run() diff --git a/tests/unit/events/test_start_actions_subtask_event.py b/tests/unit/events/test_start_actions_subtask_event.py index 8b628057c..b7236911f 100644 --- a/tests/unit/events/test_start_actions_subtask_event.py +++ b/tests/unit/events/test_start_actions_subtask_event.py @@ -3,7 +3,6 @@ from griptape.events import StartActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -17,7 +16,7 @@ def start_subtask_event(self): "Answer: test output" ) task = ToolkitTask(tools=[MockTool()]) - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) subtask = ActionsSubtask(valid_input) task.add_subtask(subtask) diff --git a/tests/unit/events/test_start_task_event.py b/tests/unit/events/test_start_task_event.py index ea027f147..111d35934 100644 --- a/tests/unit/events/test_start_task_event.py +++ b/tests/unit/events/test_start_task_event.py @@ -3,14 +3,13 @@ from griptape.events import StartTaskEvent from griptape.structures import Agent from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStartTaskEvent: @pytest.fixture() def start_task_event(self): task = PromptTask() - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) agent.run() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 2ffd7b8cb..77cebf193 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -60,7 +60,7 @@ def test_from_json(self): def test_buffering(self): memory = ConversationMemory(max_runs=2) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask()) @@ -75,7 +75,7 @@ def test_buffering(self): assert pipeline.conversation_memory.runs[1].input.value == "run5" def test_add_to_prompt_stack_autopruing_disabled(self): - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() memory = ConversationMemory( autoprune=False, runs=[ @@ -94,9 +94,11 @@ def test_add_to_prompt_stack_autopruing_disabled(self): assert len(prompt_stack.messages) == 12 - def test_add_to_prompt_stack_autopruning_enabled(self): + def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) + + mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ @@ -117,7 +119,8 @@ def test_add_to_prompt_stack_autopruning_enabled(self): assert len(prompt_stack.messages) == 3 # No memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) + mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ @@ -140,7 +143,8 @@ def test_add_to_prompt_stack_autopruning_enabled(self): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160))) + mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index 579be214e..42246e349 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -9,9 +9,9 @@ class TestSummaryConversationMemory: def test_unsummarized_subtasks(self): - memory = SummaryConversationMemory(offset=1, prompt_driver=MockPromptDriver()) + memory = SummaryConversationMemory(offset=1) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) @@ -23,9 +23,9 @@ def test_unsummarized_subtasks(self): assert len(memory.unsummarized_runs()) == 1 def test_after_run(self): - memory = SummaryConversationMemory(offset=1, prompt_driver=MockPromptDriver()) + memory = SummaryConversationMemory(offset=1) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index a09ad0f9a..15e1399b6 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -1,23 +1,17 @@ import pytest -from griptape.engines import PromptSummaryEngine from griptape.memory import TaskMemory from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Agent from griptape.tasks import BaseTask, PromptTask, ToolkitTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool class TestAgent: def test_init(self): - driver = MockPromptDriver() - agent = Agent(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + agent = Agent(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) - assert agent.prompt_driver is driver assert isinstance(agent.task, PromptTask) assert isinstance(agent.task, PromptTask) assert agent.rulesets[0].name == "TestRuleset" @@ -76,18 +70,6 @@ def test_with_no_task_memory_and_empty_tool_output_memory(self): assert agent.tools[0].input_memory[0] == agent.task_memory assert agent.tools[0].output_memory == {} - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - agent = Agent(tools=[MockTool()], embedding_driver=embedding_driver) - - storage = list(agent.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_without_default_task_memory(self): agent = Agent(task_memory=None, tools=[MockTool()]) @@ -95,7 +77,7 @@ def test_without_default_task_memory(self): assert agent.tools[0].output_memory is None def test_with_memory(self): - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + agent = Agent(conversation_memory=ConversationMemory()) assert agent.conversation_memory is not None assert len(agent.conversation_memory.runs) == 0 @@ -117,7 +99,7 @@ def test_tasks_initialization(self): assert agent.tasks[0] == task def test_add_task(self): - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() assert len(agent.tasks) == 1 @@ -145,7 +127,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() try: agent.add_tasks(first_task, second_task) @@ -160,7 +142,7 @@ def test_add_tasks(self): assert True def test_prompt_stack_without_memory(self): - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=None, rules=[Rule("test")]) + agent = Agent(conversation_memory=None, rules=[Rule("test")]) task1 = PromptTask("test") @@ -177,7 +159,7 @@ def test_prompt_stack_without_memory(self): assert len(task1.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory(), rules=[Rule("test")]) + agent = Agent(conversation_memory=ConversationMemory(), rules=[Rule("test")]) task1 = PromptTask("test") @@ -195,7 +177,7 @@ def test_prompt_stack_with_memory(self): def test_run(self): task = PromptTask("test") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) assert task.state == BaseTask.State.PENDING @@ -207,7 +189,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) agent._execution_args = ("test1", "test2") @@ -220,7 +202,7 @@ def test_run_with_args(self): def test_context(self): task = PromptTask("test prompt") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) @@ -230,37 +212,9 @@ def test_context(self): assert context["structure"] == agent - def test_task_memory_defaults(self): - prompt_driver = MockPromptDriver() - embedding_driver = MockEmbeddingDriver() - agent = Agent(prompt_driver=prompt_driver, embedding_driver=embedding_driver) - - storage = list(agent.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - - assert storage.rag_engine.response_stage.response_module.prompt_driver == prompt_driver - assert ( - storage.rag_engine.retrieval_stage.retrieval_modules[0].vector_store_driver.embedding_driver - == embedding_driver - ) - assert isinstance(storage.summary_engine, PromptSummaryEngine) - assert storage.summary_engine.prompt_driver == prompt_driver - assert storage.csv_extraction_engine.prompt_driver == prompt_driver - assert storage.json_extraction_engine.prompt_driver == prompt_driver - - def test_deprecation(self): - with pytest.deprecated_call(): - Agent(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Agent(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Agent(stream=True) - def finished_tasks(self): task = PromptTask("test prompt") - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) @@ -270,4 +224,4 @@ def finished_tasks(self): def test_fail_fast(self): with pytest.raises(ValueError): - Agent(prompt_driver=MockPromptDriver(), fail_fast=True) + Agent(fail_fast=True) diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index 306fd7bd2..a7f7f40c1 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -4,14 +4,11 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Pipeline from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask from griptape.tokenizers import OpenAiTokenizer -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool -from tests.unit.structures.test_agent import MockEmbeddingDriver class TestPipeline: @@ -31,10 +28,8 @@ def fn(task): return CodeExecutionTask(run_fn=fn) def test_init(self): - driver = MockPromptDriver() - pipeline = Pipeline(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + pipeline = Pipeline(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) - assert pipeline.prompt_driver is driver assert pipeline.input_task is None assert pipeline.output_task is None assert pipeline.rulesets[0].name == "TestRuleset" @@ -103,20 +98,6 @@ def test_with_task_memory(self): assert pipeline.tasks[0].tools[0].output_memory is not None assert pipeline.tasks[0].tools[0].output_memory["test"][0] == pipeline.task_memory - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - pipeline = Pipeline(embedding_driver=embedding_driver) - - pipeline.add_task(ToolkitTask(tools=[MockTool()])) - - storage = list(pipeline.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_with_task_memory_and_empty_tool_output_memory(self): pipeline = Pipeline() @@ -139,7 +120,7 @@ def test_with_memory(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline + [first_task, second_task, third_task] @@ -174,7 +155,7 @@ def test_tasks_order(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + first_task pipeline + second_task @@ -189,7 +170,7 @@ def test_add_task(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + first_task pipeline + second_task @@ -208,7 +189,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] @@ -227,7 +208,7 @@ def test_insert_task_in_middle(self): second_task = PromptTask("test2", id="test2") third_task = PromptTask("test3", id="test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] pipeline.insert_task(first_task, third_task) @@ -251,7 +232,7 @@ def test_insert_task_at_end(self): second_task = PromptTask("test2", id="test2") third_task = PromptTask("test3", id="test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] pipeline.insert_task(second_task, third_task) @@ -271,7 +252,7 @@ def test_insert_task_at_end(self): assert [child.id for child in third_task.children] == [] def test_prompt_stack_without_memory(self): - pipeline = Pipeline(conversation_memory=None, prompt_driver=MockPromptDriver(), rules=[Rule("test")]) + pipeline = Pipeline(conversation_memory=None, rules=[Rule("test")]) task1 = PromptTask("test") task2 = PromptTask("test") @@ -292,7 +273,7 @@ def test_prompt_stack_without_memory(self): assert len(task2.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), rules=[Rule("test")]) + pipeline = Pipeline(rules=[Rule("test")]) task1 = PromptTask("test") task2 = PromptTask("test") @@ -321,7 +302,7 @@ def test_text_artifact_token_count(self): def test_run(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task assert task.state == BaseTask.State.PENDING @@ -333,7 +314,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [task] pipeline._execution_args = ("test1", "test2") @@ -348,7 +329,7 @@ def test_context(self): parent = PromptTask("parent") task = PromptTask("test") child = PromptTask("child") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [parent, task, child] @@ -365,35 +346,23 @@ def test_context(self): assert context["parent"] == parent assert context["child"] == child - def test_deprecation(self): - with pytest.deprecated_call(): - Pipeline(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Pipeline(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Pipeline(stream=True) - def test_run_with_error_artifact(self, error_artifact_task, waiting_task): end_task = PromptTask("end") - pipeline = Pipeline(prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task]) + pipeline = Pipeline(tasks=[waiting_task, error_artifact_task, end_task]) pipeline.run() assert pipeline.output is None def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting_task): end_task = PromptTask("end") - pipeline = Pipeline( - prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False - ) + pipeline = Pipeline(tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False) pipeline.run() assert pipeline.output is not None def test_add_duplicate_task(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task pipeline + task @@ -402,7 +371,7 @@ def test_add_duplicate_task(self): def test_add_duplicate_task_directly(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task pipeline.tasks.append(task) diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 242de29c5..79c9868e1 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -4,12 +4,9 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Workflow from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -30,10 +27,8 @@ def fn(task): return CodeExecutionTask(run_fn=fn) def test_init(self): - driver = MockPromptDriver() - workflow = Workflow(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + workflow = Workflow(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) - assert workflow.prompt_driver is driver assert len(workflow.tasks) == 0 assert workflow.rulesets[0].name == "TestRuleset" assert workflow.rulesets[0].rules[0].value == "test" @@ -100,20 +95,6 @@ def test_with_task_memory(self): assert workflow.tasks[0].tools[0].output_memory is not None assert workflow.tasks[0].tools[0].output_memory["test"][0] == workflow.task_memory - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - workflow = Workflow(embedding_driver=embedding_driver) - - workflow.add_task(ToolkitTask(tools=[MockTool()])) - - storage = list(workflow.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_with_task_memory_and_empty_tool_output_memory(self): workflow = Workflow() @@ -136,7 +117,7 @@ def test_with_memory(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - workflow = Workflow(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + workflow = Workflow(conversation_memory=ConversationMemory()) workflow + [first_task, second_task, third_task] @@ -170,7 +151,7 @@ def test_add_task(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + first_task workflow.add_task(second_task) @@ -189,7 +170,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + [first_task, second_task] @@ -206,7 +187,7 @@ def test_add_tasks(self): def test_run(self): task1 = PromptTask("test") task2 = PromptTask("test") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + [task1, task2] assert task1.state == BaseTask.State.PENDING @@ -219,7 +200,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task workflow._execution_args = ("test1", "test2") @@ -241,7 +222,7 @@ def test_run_with_args(self): ], ) def test_run_raises_on_missing_parent_or_child_id(self, tasks): - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + workflow = Workflow(tasks=tasks) with pytest.raises(ValueError) as e: workflow.run() @@ -250,7 +231,6 @@ def test_run_raises_on_missing_parent_or_child_id(self, tasks): def test_run_topology_1_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task1"]), @@ -265,7 +245,6 @@ def test_run_topology_1_declarative_parents(self): def test_run_topology_1_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task2", "task3"]), PromptTask("test2", id="task2", child_ids=["task4"]), @@ -280,7 +259,6 @@ def test_run_topology_1_declarative_children(self): def test_run_topology_1_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task3"]), PromptTask("test2", id="task2", parent_ids=["task1"], child_ids=["task4"]), @@ -301,7 +279,7 @@ def test_run_topology_1_imperative_parents(self): task2.add_parent(task1) task3.add_parent("task1") task4.add_parents([task2, "task3"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() @@ -315,14 +293,14 @@ def test_run_topology_1_imperative_children(self): task1.add_children([task2, task3]) task2.add_child(task4) task3.add_child(task4) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() self._validate_topology_1(workflow) def test_run_topology_1_imperative_parents_structure_init(self): - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() task1 = PromptTask("test1", id="task1") task2 = PromptTask("test2", id="task2", structure=workflow) task3 = PromptTask("test3", id="task3", structure=workflow) @@ -336,7 +314,7 @@ def test_run_topology_1_imperative_parents_structure_init(self): self._validate_topology_1(workflow) def test_run_topology_1_imperative_children_structure_init(self): - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() task1 = PromptTask("test1", id="task1", structure=workflow) task2 = PromptTask("test2", id="task2", structure=workflow) task3 = PromptTask("test3", id="task3", structure=workflow) @@ -356,7 +334,7 @@ def test_run_topology_1_imperative_mixed(self): task4 = PromptTask("test4", id="task4") task1.add_children([task2, task3]) task4.add_parents([task2, task3]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() @@ -367,7 +345,7 @@ def test_run_topology_1_imperative_insert(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task1 splits into task2 and task3 # task2 and task3 converge into task4 @@ -384,7 +362,7 @@ def test_run_topology_1_missing_parent(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task1 never added to workflow workflow + task4 @@ -396,7 +374,7 @@ def test_run_topology_1_id_equality(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task4 never added to workflow workflow + task1 @@ -410,7 +388,7 @@ def test_run_topology_1_object_equality(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -419,7 +397,6 @@ def test_run_topology_1_object_equality(self): def test_run_topology_2_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("testa", id="taska"), PromptTask("testb", id="taskb", parent_ids=["taska"]), @@ -435,7 +412,6 @@ def test_run_topology_2_declarative_parents(self): def test_run_topology_2_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("testa", id="taska", child_ids=["taskb", "taskc", "taskd", "taske"]), PromptTask("testb", id="taskb", child_ids=["taskd"]), @@ -459,7 +435,7 @@ def test_run_topology_2_imperative_parents(self): taskc.add_parent("taska") taskd.add_parents([taska, taskb, taskc]) taske.add_parents(["taska", taskd, "taskc"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -475,7 +451,7 @@ def test_run_topology_2_imperative_children(self): taskb.add_child(taskd) taskc.add_children([taskd, taske]) taskd.add_child(taske) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -491,7 +467,7 @@ def test_run_topology_2_imperative_mixed(self): taskb.add_child(taskd) taskd.add_parent(taskc) taske.add_parents(["taska", taskd, "taskc"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -503,7 +479,7 @@ def test_run_topology_2_imperative_insert(self): taskc = PromptTask("testc", id="taskc") taskd = PromptTask("testd", id="taskd") taske = PromptTask("teste", id="taske") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow.add_task(taska) workflow.add_task(taske) taske.add_parent(taska) @@ -517,7 +493,6 @@ def test_run_topology_2_imperative_insert(self): def test_run_topology_3_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task4"]), @@ -532,7 +507,6 @@ def test_run_topology_3_declarative_parents(self): def test_run_topology_3_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task4"]), PromptTask("test2", id="task2", child_ids=["task3"]), @@ -547,7 +521,6 @@ def test_run_topology_3_declarative_children(self): def test_run_topology_3_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task4"], child_ids=["task3"]), @@ -565,7 +538,7 @@ def test_run_topology_3_imperative_insert(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task2 @@ -580,7 +553,6 @@ def test_run_topology_3_imperative_insert(self): def test_run_topology_4_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info"), PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"]), @@ -600,7 +572,6 @@ def test_run_topology_4_declarative_parents(self): def test_run_topology_4_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info", child_ids=["movie_info_1", "movie_info_2", "movie_info_3"]), PromptTask(id="movie_info_1", child_ids=["compare_movies"]), @@ -620,7 +591,6 @@ def test_run_topology_4_declarative_children(self): def test_run_topology_4_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info"), PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), @@ -650,7 +620,7 @@ def test_run_topology_4_imperative_insert(self): publish_website = PromptTask(id="publish_website") movie_info_3 = PromptTask(id="movie_info_3") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow.add_tasks(collect_movie_info, summarize_to_slack) workflow.insert_tasks(collect_movie_info, [movie_info_1, movie_info_2, movie_info_3], summarize_to_slack) workflow.insert_tasks([movie_info_1, movie_info_2, movie_info_3], compare_movies, summarize_to_slack) @@ -672,7 +642,7 @@ def test_run_topology_4_imperative_insert(self): ], ) def test_run_raises_on_cycle(self, tasks): - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + workflow = Workflow(tasks=tasks) with pytest.raises(ValueError) as e: workflow.run() @@ -684,7 +654,7 @@ def test_input_task(self): task2 = PromptTask("prompt2") task3 = PromptTask("prompt3") task4 = PromptTask("prompt4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -697,7 +667,7 @@ def test_output_task(self): task2 = PromptTask("prompt2") task3 = PromptTask("prompt3") task4 = PromptTask("prompt4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -709,7 +679,7 @@ def test_output_task(self): task1.add_children([task2, task3]) # task4 is the final task, but its defined at index 0 - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task4, task1, task2, task3]) + workflow = Workflow(tasks=[task4, task1, task2, task3]) # output_task topologically should be task4 assert task4 == workflow.output_task @@ -719,7 +689,7 @@ def test_to_graph(self): task2 = PromptTask("prompt2", id="task2") task3 = PromptTask("prompt3", id="task3") task4 = PromptTask("prompt4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -736,7 +706,7 @@ def test_order_tasks(self): task2 = PromptTask("prompt2", id="task2") task3 = PromptTask("prompt3", id="task3") task4 = PromptTask("prompt4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -753,7 +723,7 @@ def test_context(self): parent = PromptTask("parent") task = PromptTask("test") child = PromptTask("child") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + parent workflow + task @@ -776,20 +746,10 @@ def test_context(self): assert context["parents"] == {parent.id: parent} assert context["children"] == {child.id: child} - def test_deprecation(self): - with pytest.deprecated_call(): - Workflow(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Workflow(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Workflow(stream=True) - def test_run_with_error_artifact(self, error_artifact_task, waiting_task): end_task = PromptTask("end") end_task.add_parents([error_artifact_task, waiting_task]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task]) + workflow = Workflow(tasks=[waiting_task, error_artifact_task, end_task]) workflow.run() assert workflow.output is None @@ -797,9 +757,7 @@ def test_run_with_error_artifact(self, error_artifact_task, waiting_task): def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting_task): end_task = PromptTask("end") end_task.add_parents([error_artifact_task, waiting_task]) - workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False - ) + workflow = Workflow(tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False) workflow.run() assert workflow.output is not None diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index 4cc860c32..33405ad10 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -6,7 +6,6 @@ from griptape.engines import AudioTranscriptionEngine from griptape.structures import Agent, Pipeline from griptape.tasks import AudioTranscriptionTask, BaseTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestAudioTranscriptionTask: @@ -41,7 +40,7 @@ def test_run(self, audio_artifact, audio_transcription_engine): audio_transcription_engine.run.return_value = TextArtifact("mock transcription") task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) assert pipeline.run().output.to_text() == "mock transcription" diff --git a/tests/unit/tasks/test_base_multi_text_input_task.py b/tests/unit/tasks/test_base_multi_text_input_task.py index 3d8d67a55..8eaa832ae 100644 --- a/tests/unit/tasks/test_base_multi_text_input_task.py +++ b/tests/unit/tasks/test_base_multi_text_input_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import TextArtifact from griptape.structures import Pipeline from tests.mocks.mock_multi_text_input_task import MockMultiTextInputTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestBaseMultiTextInputTask: @@ -42,7 +41,7 @@ def test_full_context(self): parent = MockMultiTextInputTask(("parent1", "parent2")) subtask = MockMultiTextInputTask(("test1", "test2"), context={"foo": "bar"}) child = MockMultiTextInputTask(("child2", "child2")) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_tasks(parent, subtask, child) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 636515106..d22ef35f7 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -7,8 +7,6 @@ from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_task import MockTask from tests.mocks.mock_tool.tool import MockTool @@ -18,10 +16,9 @@ class TestBaseTask: def task(self): EventBus.event_listeners = [EventListener(handler=Mock())] agent = Agent( - prompt_driver=MockPromptDriver(), - embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], ) + Config.event_listeners = [EventListener(handler=Mock())] agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index 86dc98805..ff6afe42b 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import TextArtifact from griptape.rules import Rule, Ruleset from griptape.structures import Pipeline -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_text_input_task import MockTextInputTask @@ -31,7 +30,7 @@ def test_full_context(self): parent = MockTextInputTask("parent") subtask = MockTextInputTask("test", context={"foo": "bar"}) child = MockTextInputTask("child") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_tasks(parent, subtask, child) diff --git a/tests/unit/tasks/test_code_execution_task.py b/tests/unit/tasks/test_code_execution_task.py index 3178e29db..e2c492fad 100644 --- a/tests/unit/tasks/test_code_execution_task.py +++ b/tests/unit/tasks/test_code_execution_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact from griptape.structures import Pipeline from griptape.tasks import CodeExecutionTask -from tests.mocks.mock_prompt_driver import MockPromptDriver def hello_world(task: CodeExecutionTask) -> BaseArtifact: @@ -27,7 +26,7 @@ def test_hello_world_fn(self): # Using a Pipeline # Overriding the input because we are implementing the task not the Pipeline def test_noop_fn(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() task = CodeExecutionTask("No Op", run_fn=non_outputting) pipeline.add_task(task) temp = task.run() diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index afa73a506..76a4c3bd2 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -3,15 +3,12 @@ from griptape.engines import CsvExtractionEngine from griptape.structures import Agent from griptape.tasks import ExtractionTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestExtractionTask: @pytest.fixture() def task(self): - return ExtractionTask( - extraction_engine=CsvExtractionEngine(prompt_driver=MockPromptDriver()), args={"column_names": ["test1"]} - ) + return ExtractionTask(extraction_engine=CsvExtractionEngine(), args={"column_names": ["test1"]}) def test_run(self, task): agent = Agent() diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 4a618e0d1..cfe853226 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -12,7 +12,7 @@ class TestPromptTask: def test_run(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) diff --git a/tests/unit/tasks/test_rag_task.py b/tests/unit/tasks/test_rag_task.py index b205d385a..f70a61bdd 100644 --- a/tests/unit/tasks/test_rag_task.py +++ b/tests/unit/tasks/test_rag_task.py @@ -5,7 +5,6 @@ from griptape.engines.rag.stages import ResponseRagStage from griptape.structures import Agent from griptape.tasks import RagTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestRagTask: @@ -13,11 +12,7 @@ class TestRagTask: def task(self): return RagTask( input="test", - rag_engine=RagEngine( - response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver()) - ) - ), + rag_engine=RagEngine(response_stage=ResponseRagStage(response_module=PromptResponseRagModule())), ) def test_run(self, task): diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 1053ade9e..8df0e6598 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -5,9 +5,11 @@ class TestStructureRunTask: - def test_run(self): - agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) - pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) + def test_run(self, mock_config): + mock_config.prompt_driver = MockPromptDriver(mock_output="agent mock output") + agent = Agent() + mock_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") + pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) task = StructureRunTask(driver=driver) diff --git a/tests/unit/tasks/test_text_summary_task.py b/tests/unit/tasks/test_text_summary_task.py index 438d2bae4..f83075f2a 100644 --- a/tests/unit/tasks/test_text_summary_task.py +++ b/tests/unit/tasks/test_text_summary_task.py @@ -6,7 +6,7 @@ class TestTextSummaryTask: def test_run(self): - task = TextSummaryTask("test", summary_engine=PromptSummaryEngine(prompt_driver=MockPromptDriver())) + task = TextSummaryTask("test", summary_engine=PromptSummaryEngine()) agent = Agent() agent.add_task(task) diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index 3c629c69d..44348fef0 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -4,7 +4,6 @@ from griptape.engines import TextToSpeechEngine from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestTextToSpeechTask: @@ -40,7 +39,7 @@ def test_run(self): text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3") task = TextToSpeechTask("some text", text_to_speech_engine=text_to_speech_engine) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) assert isinstance(pipeline.run().output, AudioArtifact) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index dfc679919..90a7075fa 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -5,7 +5,6 @@ from griptape.artifacts import TextArtifact from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -166,13 +165,12 @@ class TestToolTask: } @pytest.fixture() - def agent(self): + def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - return Agent( - prompt_driver=MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}"), - embedding_driver=MockEmbeddingDriver(), - ) + mock_config.prompt_driver = MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}") + + return Agent() def test_run_without_memory(self, agent): task = ToolTask(tool=MockTool()) diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index a47e4687b..1b89ddf70 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -2,7 +2,6 @@ from griptape.common import ToolAction from griptape.structures import Agent from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -175,7 +174,7 @@ def test_run(self, mock_config): mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) @@ -190,7 +189,7 @@ def test_run_max_subtasks(self, mock_config): mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) @@ -204,7 +203,7 @@ def test_run_invalid_react_prompt(self, mock_config): mock_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) diff --git a/tests/unit/tools/test_structure_run_client.py b/tests/unit/tools/test_structure_run_client.py index d498b7c56..ee76d4da1 100644 --- a/tests/unit/tools/test_structure_run_client.py +++ b/tests/unit/tools/test_structure_run_client.py @@ -3,14 +3,12 @@ from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver from griptape.structures import Agent from griptape.tools import StructureRunClient -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureRunClient: @pytest.fixture() def client(self): - driver = MockPromptDriver() - agent = Agent(prompt_driver=driver) + agent = Agent() return StructureRunClient( description="foo bar", driver=LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/utils/test_chat.py b/tests/unit/utils/test_chat.py index 42ecc59c3..5f97d1baf 100644 --- a/tests/unit/utils/test_chat.py +++ b/tests/unit/utils/test_chat.py @@ -1,14 +1,13 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Agent from griptape.utils import Chat -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestConversation: def test_init(self): import logging - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + agent = Agent(conversation_memory=ConversationMemory()) chat = Chat( agent, diff --git a/tests/unit/utils/test_conversation.py b/tests/unit/utils/test_conversation.py index 28ee72409..a07d15cdb 100644 --- a/tests/unit/utils/test_conversation.py +++ b/tests/unit/utils/test_conversation.py @@ -2,12 +2,11 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from griptape.utils import Conversation -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestConversation: def test_lines(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) @@ -22,7 +21,7 @@ def test_lines(self): assert lines[3] == "A: mock output" def test_prompt_stack_conversation_memory(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) @@ -36,8 +35,7 @@ def test_prompt_stack_conversation_memory(self): def test_prompt_stack_summary_conversation_memory(self): pipeline = Pipeline( - prompt_driver=MockPromptDriver(), - conversation_memory=SummaryConversationMemory(summary="foobar", prompt_driver=MockPromptDriver()), + conversation_memory=SummaryConversationMemory(summary="foobar"), ) pipeline.add_tasks(PromptTask("question 1")) @@ -52,7 +50,7 @@ def test_prompt_stack_summary_conversation_memory(self): assert lines[2] == "assistant: mock output" def test___str__(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py index a9c122126..00df6958d 100644 --- a/tests/unit/utils/test_file_utils.py +++ b/tests/unit/utils/test_file_utils.py @@ -3,7 +3,6 @@ from griptape import utils from griptape.loaders import TextLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -32,7 +31,7 @@ def test_load_files(self): def test_load_file_with_loader(self): dirname = os.path.dirname(__file__) file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) - artifacts = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()).load(file) + artifacts = TextLoader(max_tokens=MAX_TOKENS).load(file) assert len(artifacts) == 39 assert isinstance(artifacts, list) @@ -43,7 +42,7 @@ def test_load_files_with_loader(self): sources = ["resources/foobar-many.txt"] sources = [os.path.join(dirname, "../../", source) for source in sources] files = utils.load_files(sources) - loader = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + loader = TextLoader(max_tokens=MAX_TOKENS) collection = loader.load_collection(list(files.values())) test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] diff --git a/tests/unit/utils/test_structure_visualizer.py b/tests/unit/utils/test_structure_visualizer.py index f6e621b91..8a055cb21 100644 --- a/tests/unit/utils/test_structure_visualizer.py +++ b/tests/unit/utils/test_structure_visualizer.py @@ -1,12 +1,11 @@ from griptape.structures import Agent, Pipeline, Workflow from griptape.tasks import PromptTask from griptape.utils import StructureVisualizer -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureVisualizer: def test_agent(self): - agent = Agent(prompt_driver=MockPromptDriver(), tasks=[PromptTask("test1", id="task1")]) + agent = Agent(tasks=[PromptTask("test1", id="task1")]) visualizer = StructureVisualizer(agent) result = visualizer.to_url() @@ -15,7 +14,6 @@ def test_agent(self): def test_pipeline(self): pipeline = Pipeline( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2"), @@ -34,7 +32,6 @@ def test_pipeline(self): def test_workflow(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task1"]), diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index bad7f0d79..e3bcde29b 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -17,9 +17,9 @@ def text_tool_artifact_storage(): rag_engine=rag_engine(MockPromptDriver(), vector_store_driver), vector_store_driver=vector_store_driver, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - summary_engine=PromptSummaryEngine(prompt_driver=MockPromptDriver()), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=MockPromptDriver()), - json_extraction_engine=JsonExtractionEngine(prompt_driver=MockPromptDriver()), + summary_engine=PromptSummaryEngine(), + csv_extraction_engine=CsvExtractionEngine(), + json_extraction_engine=JsonExtractionEngine(), ) diff --git a/tests/utils/test_reference_utils.py b/tests/utils/test_reference_utils.py index c3491f5d0..47da18713 100644 --- a/tests/utils/test_reference_utils.py +++ b/tests/utils/test_reference_utils.py @@ -1,12 +1,11 @@ from griptape.artifacts import TextArtifact from griptape.common import Reference from griptape.engines.rag.modules import PromptResponseRagModule -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestReferenceUtils: def test_references_from_artifacts(self): - module = PromptResponseRagModule(prompt_driver=MockPromptDriver()) + module = PromptResponseRagModule() reference1 = Reference(title="foo") reference2 = Reference(title="bar") artifacts = [ From e8c1fff7155dcdfd6ecf3b8dcdf5322ff3f48a49 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:21:06 -0700 Subject: [PATCH 05/40] Namespace config --- CHANGELOG.md | 4 ++-- griptape/config/base_structure_config.py | 9 +-------- griptape/config/config.py | 16 +++++++++++++++- .../engines/audio/audio_transcription_engine.py | 2 +- griptape/engines/audio/text_to_speech_engine.py | 2 +- .../engines/extraction/base_extraction_engine.py | 2 +- .../image/base_image_generation_engine.py | 2 +- .../engines/image_query/image_query_engine.py | 4 +++- .../response/prompt_response_rag_module.py | 2 +- .../vector_store_retrieval_rag_module.py | 2 +- .../engines/summary/prompt_summary_engine.py | 2 +- .../memory/structure/base_conversation_memory.py | 4 ++-- .../structure/summary_conversation_memory.py | 2 +- .../memory/task/storage/text_artifact_storage.py | 2 +- griptape/structures/structure.py | 8 ++++---- griptape/tasks/prompt_task.py | 2 +- griptape/utils/chat.py | 4 ++-- griptape/utils/stream.py | 2 +- tests/unit/conftest.py | 9 +-------- .../drivers/prompt/test_base_prompt_driver.py | 6 +++--- .../test_local_structure_run_driver.py | 2 +- tests/unit/events/test_event_listener.py | 2 +- .../memory/structure/test_conversation_memory.py | 8 +++++--- tests/unit/tasks/test_json_extraction_task.py | 4 +++- tests/unit/tasks/test_structure_run_task.py | 4 ++-- tests/unit/tasks/test_tool_task.py | 4 +++- tests/unit/tasks/test_toolkit_task.py | 6 +++--- tests/unit/utils/test_stream.py | 4 ++-- 28 files changed, 64 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea88983f3..c0720ca47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -263,7 +263,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. -- **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. +- **BREAKING**: Removed `StructureConfig.drivers.global_drivers`. Pass Drivers directly to the Structure Config instead. - **BREAKING**: Removed `StructureConfig.task_memory` in favor of configuring directly on the Structure. - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. @@ -391,7 +391,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Deprecation warnings not displaying for `Structure.prompt_driver`, `Structure.embedding_driver`, and `Structure.stream`. - `DummyException` error message not fully displaying. -- `StructureConfig.task_memory` not defaulting to using `StructureConfig.global_drivers` by default. +- `StructureConfig.task_memory` not defaulting to using `StructureConfig.drivers.global_drivers` by default. ## [0.23.1] - 2024-03-07 diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index 84743c4da..bc9238df2 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -6,7 +6,6 @@ from attrs import define, field from griptape.config import BaseConfig -from griptape.utils import dict_merge if TYPE_CHECKING: from griptape.drivers import ( @@ -22,7 +21,7 @@ @define -class BaseStructureConfig(BaseConfig, ABC, EventPublisherMixin): +class BaseStructureConfig(BaseConfig, ABC): prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) @@ -35,9 +34,3 @@ class BaseStructureConfig(BaseConfig, ABC, EventPublisherMixin): ) text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) - - def merge_config(self, config: dict) -> BaseStructureConfig: - base_config = self.to_dict() - merged_config = dict_merge(base_config, config) - - return BaseStructureConfig.from_dict(merged_config) diff --git a/griptape/config/config.py b/griptape/config/config.py index e3017f8b6..3985abca2 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,3 +1,17 @@ +from attrs import define + +from griptape.config.base_config import BaseConfig +from griptape.config.base_structure_config import BaseStructureConfig +from griptape.mixins.event_publisher_mixin import EventPublisherMixin + from .openai_structure_config import OpenAiStructureConfig -Config = OpenAiStructureConfig() + +@define +class _Config(BaseConfig, EventPublisherMixin): + drivers: BaseStructureConfig + + +Config = _Config( + drivers=OpenAiStructureConfig(), +) diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index a3769842d..aad669d70 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -8,7 +8,7 @@ @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: Config.audio_transcription_driver), kw_only=True + default=Factory(lambda: Config.drivers.audio_transcription_driver), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index 361ecc127..16634ce45 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: Config.text_to_speech_driver), kw_only=True + default=Factory(lambda: Config.drivers.text_to_speech_driver), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 3ff6a96e3..03826ab43 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -18,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index eabf38be3..4187dde79 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: Config.image_generation_driver) + kw_only=True, default=Factory(lambda: Config.drivers.image_generation_driver) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index ed6a64ee3..5090e2f27 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -13,7 +13,9 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: Config.image_query_driver), kw_only=True) + image_query_driver: BaseImageQueryDriver = field( + default=Factory(lambda: Config.drivers.image_query_driver), kw_only=True + ) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 2e7b486b6..979723beb 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -17,7 +17,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): answer_token_offset: int = field(default=400) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index b0deca67d..392a6836d 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -18,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store_driver)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index d06ebaa2f..2586a8e0c 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index fb1cfdd8b..3c3a0aaca 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: Config.conversation_memory_driver), kw_only=True + default=Factory(lambda: Config.drivers.conversation_memory_driver), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = Config.prompt_driver + prompt_driver = Config.drivers.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 807775d63..161a68eb3 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.prompt_driver)) + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.drivers.prompt_driver)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 8a918c5f2..134274648 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -16,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store_driver)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 73b5e617a..010e8ef1f 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -140,10 +140,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=Config.vector_store_driver, - summary_engine=PromptSummaryEngine(prompt_driver=Config.prompt_driver), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.prompt_driver), - json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.prompt_driver), + vector_store_driver=Config.drivers.vector_store_driver, + summary_engine=PromptSummaryEngine(prompt_driver=Config.drivers.prompt_driver), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.drivers.prompt_driver), + json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.drivers.prompt_driver), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 9f698787f..6997c9558 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -17,7 +17,7 @@ @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index a8bdc9b13..6455efd14 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -27,7 +27,7 @@ class Chat: def default_output_fn(self, text: str) -> None: from griptape.config import Config - if Config.prompt_driver.stream: + if Config.drivers.prompt_driver.stream: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 @@ -44,7 +44,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if Config.prompt_driver.stream: + if Config.drivers.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 7b5381202..7c716787b 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -36,7 +36,7 @@ class Stream: def validate_structure(self, _: Attribute, structure: Structure) -> None: from griptape.config import Config - if not Config.prompt_driver.stream: + if not Config.drivers.prompt_driver.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7d2f8203d..e49de0021 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -16,13 +16,6 @@ def event_bus(): @pytest.fixture(autouse=True) def mock_config(): - mock_structure_config = MockStructureConfig() - Config.prompt_driver = mock_structure_config.prompt_driver - Config.image_generation_driver = mock_structure_config.image_generation_driver - Config.image_query_driver = mock_structure_config.image_query_driver - Config.embedding_driver = mock_structure_config.embedding_driver - Config.vector_store_driver = mock_structure_config.vector_store_driver - Config.text_to_speech_driver = mock_structure_config.text_to_speech_driver - Config.audio_transcription_driver = mock_structure_config.audio_transcription_driver + Config.drivers = MockStructureConfig() return Config diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index d95e7a5a7..84fd0bed1 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -11,7 +11,7 @@ class TestBasePromptDriver: def test_run_via_pipeline_retries_success(self, mock_config): - mock_config.prompt_driver = MockPromptDriver(max_attempts=2) + mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=2) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -19,7 +19,7 @@ def test_run_via_pipeline_retries_success(self, mock_config): assert isinstance(pipeline.run().output_task.output, TextArtifact) def test_run_via_pipeline_retries_failure(self, mock_config): - mock_config.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) + mock_config.drivers.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -47,7 +47,7 @@ def test_run_with_stream(self): assert result.value == "mock output" def test_run_with_tools(self, mock_config): - mock_config.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) + mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 318a41aa2..b2e9c069b 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -20,7 +20,7 @@ def test_run(self): def test_run_with_env(self, mock_config): pipeline = Pipeline() - mock_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index ed978db78..d2681877f 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -26,7 +26,7 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - mock_config.prompt_driver = MockPromptDriver(stream=True) + mock_config.drivers.prompt_driver = MockPromptDriver(stream=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) pipeline = Pipeline() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 77cebf193..06e54e6c4 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -97,7 +97,7 @@ def test_add_to_prompt_stack_autopruing_disabled(self): def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -119,7 +119,9 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): assert len(prompt_stack.messages) == 3 # No memory is pruned. - mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000)) + mock_config.drivers.prompt_driver = MockPromptDriver( + tokenizer=MockTokenizer(model="foo", max_input_tokens=1000) + ) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -143,7 +145,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - mock_config.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) agent = Agent() memory = ConversationMemory( autoprune=True, diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index 0189e6679..3eef4eec3 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -13,7 +13,9 @@ def task(self): return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) def test_run(self, task, mock_config): - mock_config.prompt_driver.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' + mock_config.drivers.prompt_driver.mock_output = ( + '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' + ) agent = Agent() agent.add_task(task) diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 8df0e6598..2c0dc1b28 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -6,9 +6,9 @@ class TestStructureRunTask: def test_run(self, mock_config): - mock_config.prompt_driver = MockPromptDriver(mock_output="agent mock output") + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="agent mock output") agent = Agent() - mock_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 90a7075fa..70ab05e12 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -168,7 +168,9 @@ class TestToolTask: def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - mock_config.prompt_driver = MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}") + mock_config.drivers.prompt_driver = MockPromptDriver( + mock_output=f"```python foo bar\n{json.dumps(output_dict)}" + ) return Agent() diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 1b89ddf70..15f5a59b1 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -171,7 +171,7 @@ def test_init(self): def test_run(self, mock_config): output = """Answer: done""" - mock_config.prompt_driver.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent() @@ -186,7 +186,7 @@ def test_run(self, mock_config): def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' - mock_config.prompt_driver.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() @@ -200,7 +200,7 @@ def test_run_max_subtasks(self, mock_config): def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" - mock_config.prompt_driver.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 555daa4fd..48dbaae29 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -10,11 +10,11 @@ class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - Config.prompt_driver.stream = request.param + Config.drivers.prompt_driver.stream = request.param return Agent() def test_init(self, agent): - if Config.prompt_driver.stream: + if Config.drivers.prompt_driver.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent From 514665f8e7bdd9582ea1332045971a5fa563b21e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:35:52 -0700 Subject: [PATCH 06/40] Rename driver config fields --- docs/examples/multiple-agent-shared-memory.md | 9 ++- docs/examples/talk-to-a-video.md | 4 +- .../drivers/embedding-drivers.md | 6 +- .../drivers/event-listener-drivers.md | 4 +- .../drivers/prompt-drivers.md | 61 ++++++++--------- docs/griptape-framework/misc/events.md | 10 +-- docs/griptape-framework/structures/config.md | 66 +++++++++---------- .../structures/task-memory.md | 4 +- .../official-tools/rest-api-client.md | 6 +- griptape/config/__init__.py | 32 ++++----- ...fig.py => amazon_bedrock_driver_config.py} | 16 ++--- ...e_config.py => anthropic_driver_config.py} | 12 ++-- ...onfig.py => azure_openai_driver_config.py} | 16 ++--- griptape/config/base_driver_config.py | 34 ++++++++++ griptape/config/base_structure_config.py | 36 ---------- ...ture_config.py => cohere_driver_config.py} | 12 ++-- griptape/config/config.py | 8 +-- .../{structure_config.py => driver_config.py} | 20 +++--- ...ture_config.py => google_driver_config.py} | 10 +-- ...ture_config.py => openai_driver_config.py} | 18 ++--- .../audio/audio_transcription_engine.py | 2 +- .../engines/audio/text_to_speech_engine.py | 2 +- .../extraction/base_extraction_engine.py | 2 +- .../image/base_image_generation_engine.py | 2 +- .../engines/image_query/image_query_engine.py | 4 +- .../response/prompt_response_rag_module.py | 2 +- .../vector_store_retrieval_rag_module.py | 2 +- .../engines/summary/prompt_summary_engine.py | 2 +- griptape/exceptions/dummy_exception.py | 2 +- .../structure/base_conversation_memory.py | 4 +- .../structure/summary_conversation_memory.py | 2 +- .../task/storage/text_artifact_storage.py | 2 +- griptape/structures/structure.py | 8 +-- griptape/tasks/prompt_task.py | 2 +- griptape/utils/chat.py | 4 +- griptape/utils/stream.py | 2 +- ...ucture_config.py => mock_driver_config.py} | 18 +++-- ...y => test_amazon_bedrock_driver_config.py} | 47 +++++++------ ...fig.py => test_anthropic_driver_config.py} | 24 +++---- ....py => test_azure_openai_driver_config.py} | 22 +++---- ...config.py => test_cohere_driver_config.py} | 22 +++---- tests/unit/config/test_driver_config.py | 39 +++++++++++ ...config.py => test_google_driver_config.py} | 24 +++---- ...config.py => test_openai_driver_config.py} | 24 +++---- tests/unit/config/test_structure_config.py | 39 ----------- tests/unit/conftest.py | 4 +- .../drivers/prompt/test_base_prompt_driver.py | 6 +- .../test_local_structure_run_driver.py | 2 +- tests/unit/events/test_event_listener.py | 2 +- .../structure/test_conversation_memory.py | 8 +-- tests/unit/tasks/test_json_extraction_task.py | 4 +- tests/unit/tasks/test_structure_run_task.py | 4 +- tests/unit/tasks/test_tool_task.py | 4 +- tests/unit/tasks/test_toolkit_task.py | 6 +- tests/unit/utils/test_stream.py | 4 +- tests/utils/structure_tester.py | 4 +- 56 files changed, 355 insertions(+), 380 deletions(-) rename griptape/config/{amazon_bedrock_structure_config.py => amazon_bedrock_driver_config.py} (84%) rename griptape/config/{anthropic_structure_config.py => anthropic_driver_config.py} (76%) rename griptape/config/{azure_openai_structure_config.py => azure_openai_driver_config.py} (89%) create mode 100644 griptape/config/base_driver_config.py delete mode 100644 griptape/config/base_structure_config.py rename griptape/config/{cohere_structure_config.py => cohere_driver_config.py} (76%) rename griptape/config/{structure_config.py => driver_config.py} (74%) rename griptape/config/{google_structure_config.py => google_driver_config.py} (75%) rename griptape/config/{openai_structure_config.py => openai_driver_config.py} (76%) rename tests/mocks/{mock_structure_config.py => mock_driver_config.py} (63%) rename tests/unit/config/{test_amazon_bedrock_structure_config.py => test_amazon_bedrock_driver_config.py} (71%) rename tests/unit/config/{test_anthropic_structure_config.py => test_anthropic_driver_config.py} (65%) rename tests/unit/config/{test_azure_openai_structure_config.py => test_azure_openai_driver_config.py} (84%) rename tests/unit/config/{test_cohere_structure_config.py => test_cohere_driver_config.py} (59%) create mode 100644 tests/unit/config/test_driver_config.py rename tests/unit/config/{test_google_structure_config.py => test_google_driver_config.py} (63%) rename tests/unit/config/{test_openai_structure_config.py => test_openai_driver_config.py} (81%) delete mode 100644 tests/unit/config/test_structure_config.py diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index 109394d49..e6b092965 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -11,8 +11,7 @@ import os from griptape.tools import WebScraper, TaskMemoryClient from griptape.structures import Agent from griptape.drivers import AzureOpenAiEmbeddingDriver, AzureMongoDbVectorStoreDriver -from griptape.config import AzureOpenAiStructureConfig - +from griptape.config import AzureOpenAiDriverConfig AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"] AZURE_OPENAI_API_KEY_1 = os.environ["AZURE_OPENAI_API_KEY_1"] @@ -26,7 +25,6 @@ MONGODB_INDEX_NAME = os.environ["MONGODB_INDEX_NAME"] MONGODB_VECTOR_PATH = os.environ["MONGODB_VECTOR_PATH"] MONGODB_CONNECTION_STRING = f"mongodb+srv://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@{MONGODB_HOST}/{MONGODB_DATABASE_NAME}?tls=true&authMechanism=SCRAM-SHA-256&retrywrites=false&maxIdleTimeMS=120000" - embedding_driver = AzureOpenAiEmbeddingDriver( model='text-embedding-ada-002', azure_endpoint=AZURE_OPENAI_ENDPOINT_1, @@ -42,7 +40,7 @@ mongo_driver = AzureMongoDbVectorStoreDriver( vector_path=MONGODB_VECTOR_PATH, ) -config = AzureOpenAiStructureConfig( +config = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, vector_store_driver=mongo_driver, embedding_driver=embedding_driver, @@ -64,6 +62,7 @@ asker = Agent( ) if __name__ == "__main__": - loader.run("Load https://medium.com/enterprise-rag/a-first-intro-to-complex-rag-retrieval-augmented-generation-a8624d70090f") + loader.run( + "Load https://medium.com/enterprise-rag/a-first-intro-to-complex-rag-retrieval-augmented-generation-a8624d70090f") asker.run("why is retrieval augmented generation useful?") ``` diff --git a/docs/examples/talk-to-a-video.md b/docs/examples/talk-to-a-video.md index 9673bd1c3..310b6d407 100644 --- a/docs/examples/talk-to-a-video.md +++ b/docs/examples/talk-to-a-video.md @@ -7,7 +7,7 @@ import time from griptape.structures import Agent from griptape.tasks import PromptTask from griptape.artifacts import GenericArtifact, TextArtifact -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig import google.generativeai as genai video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") @@ -19,7 +19,7 @@ if video_file.state.name == "FAILED": raise ValueError(video_file.state.name) agent = Agent( - config=GoogleStructureConfig(), + config=GoogleDriverConfig(), input=[ GenericArtifact(video_file), TextArtifact("Answer this question regarding the video: {{ args[0] }}"), diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 567aa13e4..de2f2d379 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -211,7 +211,7 @@ print(embeddings[:3]) ``` ### Override Default Structure Embedding Driver -Here is how you can override the Embedding Driver that is used by default in Structures. +Here is how you can override the Embedding Driver that is used by default in Structures. ```python from griptape.structures import Agent @@ -220,11 +220,11 @@ from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, ) -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), embedding_driver=VoyageAiEmbeddingDriver(), ), diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 73453afb6..0adb0b10f 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -123,7 +123,7 @@ The [AwsIotCoreEventListenerDriver](../../reference/griptape/drivers/event_liste ```python import os -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver from griptape.events import ( EventListener, @@ -138,7 +138,7 @@ agent = Agent( value="You will be provided with a text, and your task is to extract the airport codes from it." ) ], - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver( model="gpt-3.5-turbo", temperature=0.7 ) diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index ab749bf7c..8693cc6ff 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -13,10 +13,10 @@ You can instantiate drivers and pass them to structures: from griptape.structures import Agent from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), ), input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", @@ -71,10 +71,10 @@ import os from griptape.structures import Agent from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver( api_key=os.environ["OPENAI_API_KEY"], temperature=0.1, @@ -106,10 +106,10 @@ Simply set the `base_url` to the service's API endpoint and the `model` to the m from griptape.structures import Agent from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver( base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True @@ -134,10 +134,10 @@ import os from griptape.structures import Agent from griptape.rules import Rule from griptape.drivers import AzureOpenAiChatPromptDriver -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-3.5-turbo", @@ -168,10 +168,10 @@ This driver uses [Cohere tool use](https://docs.cohere.com/docs/tools) when usin import os from griptape.structures import Agent from griptape.drivers import CoherePromptDriver -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=CoherePromptDriver( model="command-r", api_key=os.environ['COHERE_API_KEY'], @@ -194,10 +194,10 @@ This driver uses [Anthropic tool use](https://docs.anthropic.com/en/docs/build-w import os from griptape.structures import Agent from griptape.drivers import AnthropicPromptDriver -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AnthropicPromptDriver( model="claude-3-opus-20240229", api_key=os.environ['ANTHROPIC_API_KEY'], @@ -220,10 +220,10 @@ This driver uses [Gemini function calling](https://ai.google.dev/gemini-api/docs import os from griptape.structures import Agent from griptape.drivers import GooglePromptDriver -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=GooglePromptDriver( model="gemini-pro", api_key=os.environ['GOOGLE_API_KEY'], @@ -248,10 +248,10 @@ All models supported by the Converse API are available for use with this driver. from griptape.structures import Agent from griptape.drivers import AmazonBedrockPromptDriver from griptape.rules import Rule -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AmazonBedrockPromptDriver( model="anthropic.claude-3-sonnet-20240229-v1:0", ) @@ -285,14 +285,13 @@ The [OllamaPromptDriver](../../reference/griptape/drivers/prompt/ollama_prompt_d This driver uses [Ollama tool calling](https://ollama.com/blog/tool-support) when using [Tools](../tools/index.md). ```python -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import OllamaPromptDriver from griptape.tools import Calculator from griptape.structures import Agent - agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=OllamaPromptDriver( model="llama3.1", ), @@ -319,11 +318,10 @@ import os from griptape.structures import Agent from griptape.drivers import HuggingFaceHubPromptDriver from griptape.rules import Rule, Ruleset -from griptape.config import StructureConfig - +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=HuggingFaceHubPromptDriver( model="HuggingFaceH4/zephyr-7b-beta", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], @@ -335,8 +333,8 @@ agent = Agent( rules=[ Rule( value="You are Girafatron, a giraffe-obsessed robot. You are talking to a human. " - "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. " - "Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe." + "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. " + "Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe." ) ], ) @@ -354,11 +352,10 @@ The [HuggingFaceHubPromptDriver](#hugging-face-hub) also supports [Text Generati import os from griptape.structures import Agent from griptape.drivers import HuggingFaceHubPromptDriver -from griptape.config import StructureConfig - +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=HuggingFaceHubPromptDriver( model="http://127.0.0.1:8080", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], @@ -383,11 +380,10 @@ The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/hu from griptape.structures import Agent from griptape.drivers import HuggingFacePipelinePromptDriver from griptape.rules import Rule, Ruleset -from griptape.config import StructureConfig - +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=HuggingFacePipelinePromptDriver( model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", ) @@ -417,7 +413,6 @@ The [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prom Amazon Sagemaker Jumpstart provides a wide range of models with varying capabilities. This Driver has been primarily _chat-optimized_ models that have a [Huggingface Chat Template](https://huggingface.co/docs/transformers/en/chat_templating) available. If your model does not fit this use-case, we suggest sub-classing [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.md) and overriding the `_to_model_input` and `_to_model_params` methods. - ```python title="PYTEST_IGNORE" import os @@ -426,10 +421,10 @@ from griptape.drivers import ( AmazonSageMakerJumpstartPromptDriver, SageMakerFalconPromptModelDriver, ) -from griptape.config import StructureConfig +from griptape.config import DriverConfig agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AmazonSageMakerJumpstartPromptDriver( endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], model="meta-llama/Meta-Llama-3-8B-Instruct", diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 187321dc6..dfb6e2db3 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -135,7 +135,7 @@ from griptape.events import CompletionChunkEvent, EventListener, EventBus from griptape.tasks import ToolkitTask from griptape.structures import Pipeline from griptape.tools import WebScraper, TaskMemoryClient -from griptape.config import OpenAiStructureConfig +from griptape.config import OpenAiDriverConfig from griptape.drivers import OpenAiChatPromptDriver @@ -148,7 +148,7 @@ EventBus.event_listeners = [ ] pipeline = Pipeline( - config=OpenAiStructureConfig( + config=OpenAiDriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True) ), ) @@ -172,10 +172,10 @@ from griptape.tasks import ToolkitTask from griptape.structures import Pipeline from griptape.tools import WebScraper, TaskMemoryClient - pipeline = Pipeline() -pipeline.config.prompt_driver.stream = True -pipeline.add_tasks(ToolkitTask("Based on https://griptape.ai, tell me what griptape is.", tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)])) +pipeline.config.prompt.stream = True +pipeline.add_tasks(ToolkitTask("Based on https://griptape.ai, tell me what griptape is.", + tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)])) for artifact in Stream(pipeline).run(): print(artifact.value, end="", flush=True) diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 3f510eb86..17fb9e5da 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -5,44 +5,42 @@ search: ## Overview -The [StructureConfig](../../reference/griptape/config/structure_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. +The [StructureConfig](../../reference/griptape/config/driver_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. ### Premade Configs -Griptape provides predefined [StructureConfig](../../reference/griptape/config/structure_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. +Griptape provides predefined [StructureConfig](../../reference/griptape/config/driver_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. #### OpenAI -The [OpenAI Structure Config](../../reference/griptape/config/openai_structure_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. - +The [OpenAI Structure Config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. ```python from griptape.structures import Agent -from griptape.config import OpenAiStructureConfig +from griptape.config import OpenAiDriverConfig agent = Agent( - config=OpenAiStructureConfig() + config=OpenAiDriverConfig() ) -agent = Agent() # This is equivalent to the above +agent = Agent() # This is equivalent to the above ``` #### Azure OpenAI -The [Azure OpenAI Structure Config](../../reference/griptape/config/azure_openai_structure_config.md) provides default Drivers for Azure's OpenAI APIs. - +The [Azure OpenAI Structure Config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. ```python import os from griptape.structures import Agent -from griptape.config import AzureOpenAiStructureConfig +from griptape.config import AzureOpenAiDriverConfig agent = Agent( - config=AzureOpenAiStructureConfig( + config=AzureOpenAiDriverConfig( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ).merge_config({ - "image_query_driver": { + "image_query": { "azure_deployment": "gpt-4o", }, }), @@ -50,16 +48,16 @@ agent = Agent( ``` #### Amazon Bedrock -The [Amazon Bedrock Structure Config](../../reference/griptape/config/amazon_bedrock_structure_config.md) provides default Drivers for Amazon Bedrock's APIs. +The [Amazon Bedrock Structure Config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. ```python import os import boto3 from griptape.structures import Agent -from griptape.config import AmazonBedrockStructureConfig +from griptape.config import AmazonBedrockDriverConfig agent = Agent( - config=AmazonBedrockStructureConfig( + config=AmazonBedrockDriverConfig( session=boto3.Session( region_name=os.environ["AWS_DEFAULT_REGION"], aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], @@ -70,63 +68,61 @@ agent = Agent( ``` #### Google -The [Google Structure Config](../../reference/griptape/config/google_structure_config.md) provides default Drivers for Google's Gemini APIs. +The [Google Structure Config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. ```python from griptape.structures import Agent -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig agent = Agent( - config=GoogleStructureConfig() + config=GoogleDriverConfig() ) ``` #### Anthropic -The [Anthropic Structure Config](../../reference/griptape/config/anthropic_structure_config.md) provides default Drivers for Anthropic's APIs. +The [Anthropic Structure Config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. !!! info Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. The `AnthropicStructureConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). To override the default embedding driver, see: [Override Default Structure Embedding Driver](../drivers/embedding-drivers.md#override-default-structure-embedding-driver). - ```python from griptape.structures import Agent -from griptape.config import AnthropicStructureConfig +from griptape.config import AnthropicDriverConfig agent = Agent( - config=AnthropicStructureConfig() + config=AnthropicDriverConfig() ) ``` #### Cohere -The [Cohere Structure Config](../../reference/griptape/config/cohere_structure_config.md) provides default Drivers for Cohere's APIs. - +The [Cohere Structure Config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. ```python import os -from griptape.config import CohereStructureConfig +from griptape.config import CohereDriverConfig from griptape.structures import Agent -agent = Agent(config=CohereStructureConfig(api_key=os.environ["COHERE_API_KEY"])) +agent = Agent(config=CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])) ``` ### Custom Configs -You can create your own [StructureConfig](../../reference/griptape/config/structure_config.md) by overriding relevant Drivers. -The [StructureConfig](../../reference/griptape/config/structure_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. +You can create your own [StructureConfig](../../reference/griptape/config/driver_config.md) by overriding relevant Drivers. +The [StructureConfig](../../reference/griptape/config/driver_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations. ```python import os from griptape.structures import Agent -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import AnthropicPromptDriver agent = Agent( - config=StructureConfig( + config=DriverConfig( prompt_driver=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], @@ -141,14 +137,14 @@ Configuration classes in Griptape offer utility methods for loading, saving, and ```python from griptape.structures import Agent -from griptape.config import AmazonBedrockStructureConfig +from griptape.config import AmazonBedrockDriverConfig from griptape.drivers import AmazonBedrockCohereEmbeddingDriver -custom_config = AmazonBedrockStructureConfig() +custom_config = AmazonBedrockDriverConfig() custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver() custom_config.merge_config( { - "embedding_driver": { + "embedding": { "base_url": None, "model": "text-embedding-3-small", "organization": None, @@ -157,11 +153,11 @@ custom_config.merge_config( } ) serialized_config = custom_config.to_json() -deserialized_config = AmazonBedrockStructureConfig.from_json(serialized_config) +deserialized_config = AmazonBedrockDriverConfig.from_json(serialized_config) agent = Agent( config=deserialized_config.merge_config({ - "prompt_driver" : { + "prompt": { "model": "anthropic.claude-3-sonnet-20240229-v1:0", }, }), diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index ea4a787f6..3184c4096 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -206,7 +206,7 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s ```python from griptape.artifacts import TextArtifact from griptape.config import ( - OpenAiStructureConfig, + OpenAiDriverConfig, ) from griptape.drivers import ( LocalVectorStoreDriver, @@ -223,7 +223,7 @@ from griptape.tools import FileManager, TaskMemoryClient, WebScraper vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) agent = Agent( - config=OpenAiStructureConfig( + config=OpenAiDriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ), task_memory=TaskMemory( diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index 07ddccf86..304ec00ec 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -6,7 +6,7 @@ The [RestApiClient](../../reference/griptape/tools/rest_api_client/tool.md) tool ### Example The following example is built using [https://jsonplaceholder.typicode.com/guide/](https://jsonplaceholder.typicode.com/guide/). - + ```python from json import dumps from griptape.drivers import OpenAiChatPromptDriver @@ -14,7 +14,7 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiClient -from griptape.config import StructureConfig +from griptape.config import DriverConfig posts_client = RestApiClient( base_url="https://jsonplaceholder.typicode.com", @@ -117,7 +117,7 @@ posts_client = RestApiClient( pipeline = Pipeline( conversation_memory=ConversationMemory(), - config = StructureConfig( + config=DriverConfig( prompt_driver=OpenAiChatPromptDriver( model="gpt-4o", temperature=0.1 diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 4b0f8eb28..7450d7738 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -1,26 +1,26 @@ from .base_config import BaseConfig -from .base_structure_config import BaseStructureConfig +from .base_driver_config import BaseDriverConfig -from .structure_config import StructureConfig -from .openai_structure_config import OpenAiStructureConfig -from .azure_openai_structure_config import AzureOpenAiStructureConfig -from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig -from .anthropic_structure_config import AnthropicStructureConfig -from .google_structure_config import GoogleStructureConfig -from .cohere_structure_config import CohereStructureConfig +from .driver_config import DriverConfig +from .openai_driver_config import OpenAiDriverConfig +from .azure_openai_driver_config import AzureOpenAiDriverConfig +from .amazon_bedrock_driver_config import AmazonBedrockDriverConfig +from .anthropic_driver_config import AnthropicDriverConfig +from .google_driver_config import GoogleDriverConfig +from .cohere_driver_config import CohereDriverConfig from .config import Config __all__ = [ "BaseConfig", - "BaseStructureConfig", - "StructureConfig", - "OpenAiStructureConfig", - "AzureOpenAiStructureConfig", - "AmazonBedrockStructureConfig", - "AnthropicStructureConfig", - "GoogleStructureConfig", - "CohereStructureConfig", + "BaseDriverConfig", + "DriverConfig", + "OpenAiDriverConfig", + "AzureOpenAiDriverConfig", + "AmazonBedrockDriverConfig", + "AnthropicDriverConfig", + "GoogleDriverConfig", + "CohereDriverConfig", "Config", ] diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_driver_config.py similarity index 84% rename from griptape/config/amazon_bedrock_structure_config.py rename to griptape/config/amazon_bedrock_driver_config.py index 3ad7f8f48..a07300638 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_driver_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AmazonBedrockImageGenerationDriver, AmazonBedrockImageQueryDriver, @@ -25,14 +25,14 @@ @define -class AmazonBedrockStructureConfig(StructureConfig): +class AmazonBedrockDriverConfig(DriverConfig): session: boto3.Session = field( default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True, metadata={"serializable": False}, ) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory( lambda self: AmazonBedrockPromptDriver( session=self.session, @@ -43,7 +43,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1"), takes_self=True, @@ -51,7 +51,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory( lambda self: AmazonBedrockImageGenerationDriver( session=self.session, @@ -63,7 +63,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageGenerationDriver = field( + image_query: BaseImageGenerationDriver = field( default=Factory( lambda self: AmazonBedrockImageQueryDriver( session=self.session, @@ -75,8 +75,8 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_driver_config.py similarity index 76% rename from griptape/config/anthropic_structure_config.py rename to griptape/config/anthropic_driver_config.py index 1bb5bf49b..642a3fced 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AnthropicImageQueryDriver, AnthropicPromptDriver, @@ -14,25 +14,25 @@ @define -class AnthropicStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class AnthropicDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: AnthropicPromptDriver(model="claude-3-5-sonnet-20240620")), metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: VoyageAiEmbeddingDriver(model="voyage-large-2")), metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")), ), kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory(lambda: AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620")), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/config/azure_openai_structure_config.py b/griptape/config/azure_openai_driver_config.py similarity index 89% rename from griptape/config/azure_openai_structure_config.py rename to griptape/config/azure_openai_driver_config.py index ce0303e34..ef965fa28 100644 --- a/griptape/config/azure_openai_structure_config.py +++ b/griptape/config/azure_openai_driver_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AzureOpenAiChatPromptDriver, AzureOpenAiEmbeddingDriver, @@ -20,7 +20,7 @@ @define -class AzureOpenAiStructureConfig(StructureConfig): +class AzureOpenAiDriverConfig(DriverConfig): """Azure OpenAI Structure Configuration. Attributes: @@ -43,7 +43,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": False}, ) api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory( lambda self: AzureOpenAiChatPromptDriver( model="gpt-4o", @@ -57,7 +57,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory( lambda self: AzureOpenAiImageGenerationDriver( model="dall-e-2", @@ -72,7 +72,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory( lambda self: AzureOpenAiImageQueryDriver( model="gpt-4o", @@ -86,7 +86,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: AzureOpenAiEmbeddingDriver( model="text-embedding-3-small", @@ -100,8 +100,8 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), metadata={"serializable": True}, kw_only=True, ) diff --git a/griptape/config/base_driver_config.py b/griptape/config/base_driver_config.py new file mode 100644 index 000000000..46ff181d3 --- /dev/null +++ b/griptape/config/base_driver_config.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Optional + +from attrs import define, field + +if TYPE_CHECKING: + from griptape.drivers import ( + BaseAudioTranscriptionDriver, + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseTextToSpeechDriver, + BaseVectorStoreDriver, + ) + + +@define +class BaseDriverConfig(ABC): + prompt: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) + image_generation: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) + image_query: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) + embedding: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) + vector_store: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) + conversation_memory: Optional[BaseConversationMemoryDriver] = field( + default=None, + kw_only=True, + metadata={"serializable": True}, + ) + text_to_speech: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) + audio_transcription: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py deleted file mode 100644 index bc9238df2..000000000 --- a/griptape/config/base_structure_config.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Optional - -from attrs import define, field - -from griptape.config import BaseConfig - -if TYPE_CHECKING: - from griptape.drivers import ( - BaseAudioTranscriptionDriver, - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BaseImageQueryDriver, - BasePromptDriver, - BaseTextToSpeechDriver, - BaseVectorStoreDriver, - ) - - -@define -class BaseStructureConfig(BaseConfig, ABC): - prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) - image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) - image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) - embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) - vector_store_driver: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( - default=None, - kw_only=True, - metadata={"serializable": True}, - ) - text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) - audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/cohere_structure_config.py b/griptape/config/cohere_driver_config.py similarity index 76% rename from griptape/config/cohere_structure_config.py rename to griptape/config/cohere_driver_config.py index 2e896b9b0..7195f550f 100644 --- a/griptape/config/cohere_structure_config.py +++ b/griptape/config/cohere_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseEmbeddingDriver, BasePromptDriver, @@ -12,15 +12,15 @@ @define -class CohereStructureConfig(StructureConfig): +class CohereDriverConfig(DriverConfig): api_key: str = field(metadata={"serializable": False}, kw_only=True) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory(lambda self: CoherePromptDriver(model="command-r", api_key=self.api_key), takes_self=True), metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: CohereEmbeddingDriver( model="embed-english-v3.0", @@ -32,8 +32,8 @@ class CohereStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/config.py b/griptape/config/config.py index 3985abca2..f325d1265 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,17 +1,17 @@ from attrs import define from griptape.config.base_config import BaseConfig -from griptape.config.base_structure_config import BaseStructureConfig +from griptape.config.base_driver_config import BaseDriverConfig from griptape.mixins.event_publisher_mixin import EventPublisherMixin -from .openai_structure_config import OpenAiStructureConfig +from .openai_driver_config import OpenAiDriverConfig @define class _Config(BaseConfig, EventPublisherMixin): - drivers: BaseStructureConfig + drivers: BaseDriverConfig Config = _Config( - drivers=OpenAiStructureConfig(), + drivers=OpenAiDriverConfig(), ) diff --git a/griptape/config/structure_config.py b/griptape/config/driver_config.py similarity index 74% rename from griptape/config/structure_config.py rename to griptape/config/driver_config.py index d68b6e2e2..325591258 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/driver_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import BaseStructureConfig +from griptape.config import BaseDriverConfig from griptape.drivers import ( DummyAudioTranscriptionDriver, DummyEmbeddingDriver, @@ -29,43 +29,43 @@ @define -class StructureConfig(BaseStructureConfig): - prompt_driver: BasePromptDriver = field( +class DriverConfig(BaseDriverConfig): + prompt: BasePromptDriver = field( kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True}, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( kw_only=True, default=Factory(lambda: DummyImageQueryDriver()), metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True}, ) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + conversation_memory: Optional[BaseConversationMemoryDriver] = field( default=None, kw_only=True, metadata={"serializable": True}, ) - text_to_speech_driver: BaseTextToSpeechDriver = field( + text_to_speech: BaseTextToSpeechDriver = field( default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True}, ) - audio_transcription_driver: BaseAudioTranscriptionDriver = field( + audio_transcription: BaseAudioTranscriptionDriver = field( default=Factory(lambda: DummyAudioTranscriptionDriver()), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_driver_config.py similarity index 75% rename from griptape/config/google_structure_config.py rename to griptape/config/google_driver_config.py index 66ed90b4b..a1089f0ee 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseEmbeddingDriver, BasePromptDriver, @@ -12,18 +12,18 @@ @define -class GoogleStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class GoogleDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: GooglePromptDriver(model="gemini-1.5-pro")), kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: GoogleEmbeddingDriver(model="models/embedding-001")), kw_only=True, metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")), ), diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_driver_config.py similarity index 76% rename from griptape/config/openai_structure_config.py rename to griptape/config/openai_driver_config.py index 63806dfc9..35ccde43d 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseAudioTranscriptionDriver, BaseEmbeddingDriver, @@ -20,40 +20,40 @@ @define -class OpenAiStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class OpenAiDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: OpenAiChatPromptDriver(model="gpt-4o")), metadata={"serializable": True}, kw_only=True, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory(lambda: OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512")), kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory(lambda: OpenAiImageQueryDriver(model="gpt-4o")), kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")), ), kw_only=True, metadata={"serializable": True}, ) - text_to_speech_driver: BaseTextToSpeechDriver = field( + text_to_speech: BaseTextToSpeechDriver = field( default=Factory(lambda: OpenAiTextToSpeechDriver(model="tts")), kw_only=True, metadata={"serializable": True}, ) - audio_transcription_driver: BaseAudioTranscriptionDriver = field( + audio_transcription: BaseAudioTranscriptionDriver = field( default=Factory(lambda: OpenAiAudioTranscriptionDriver(model="whisper-1")), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index aad669d70..51022e47c 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -8,7 +8,7 @@ @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: Config.drivers.audio_transcription_driver), kw_only=True + default=Factory(lambda: Config.drivers.audio_transcription), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index 16634ce45..a163c36fd 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: Config.drivers.text_to_speech_driver), kw_only=True + default=Factory(lambda: Config.drivers.text_to_speech), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 03826ab43..a1bcbdee2 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -18,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 4187dde79..921d600c7 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: Config.drivers.image_generation_driver) + kw_only=True, default=Factory(lambda: Config.drivers.image_generation) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index 5090e2f27..d85e6012d 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -13,9 +13,7 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field( - default=Factory(lambda: Config.drivers.image_query_driver), kw_only=True - ) + image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: Config.drivers.image_query), kw_only=True) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 979723beb..8e421d792 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -17,7 +17,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): answer_token_offset: int = field(default=400) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 392a6836d..4daa10e54 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -18,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 2586a8e0c..1c45fa5ea 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/exceptions/dummy_exception.py b/griptape/exceptions/dummy_exception.py index 815cb245f..172aeadc6 100644 --- a/griptape/exceptions/dummy_exception.py +++ b/griptape/exceptions/dummy_exception.py @@ -2,7 +2,7 @@ class DummyError(Exception): def __init__(self, dummy_class_name: str, dummy_method_name: str) -> None: message = ( f"You have attempted to use a {dummy_class_name}'s {dummy_method_name} method. " - "This likely originated from using a `StructureConfig` without providing a Driver required for this feature." + "This likely originated from using a `DriverConfig` without providing a Driver required for this feature." ) super().__init__(message) diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 3c3a0aaca..e7c8ed488 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: Config.drivers.conversation_memory_driver), kw_only=True + default=Factory(lambda: Config.drivers.conversation_memory), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = Config.drivers.prompt_driver + prompt_driver = Config.drivers.prompt temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 161a68eb3..50be69a61 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.drivers.prompt_driver)) + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.drivers.prompt)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 134274648..460581997 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -16,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 010e8ef1f..4db8228e3 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -140,10 +140,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=Config.drivers.vector_store_driver, - summary_engine=PromptSummaryEngine(prompt_driver=Config.drivers.prompt_driver), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.drivers.prompt_driver), - json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.drivers.prompt_driver), + vector_store_driver=Config.drivers.vector_store, + summary_engine=PromptSummaryEngine(prompt_driver=Config.drivers.prompt), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.drivers.prompt), + json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.drivers.prompt), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 6997c9558..19580b642 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -17,7 +17,7 @@ @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 6455efd14..99b5a7dc3 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -27,7 +27,7 @@ class Chat: def default_output_fn(self, text: str) -> None: from griptape.config import Config - if Config.drivers.prompt_driver.stream: + if Config.drivers.prompt.stream: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 @@ -44,7 +44,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if Config.drivers.prompt_driver.stream: + if Config.drivers.prompt.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 7c716787b..cb5266378 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -36,7 +36,7 @@ class Stream: def validate_structure(self, _: Attribute, structure: Structure) -> None: from griptape.config import Config - if not Config.drivers.prompt_driver.stream: + if not Config.drivers.prompt.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_driver_config.py similarity index 63% rename from tests/mocks/mock_structure_config.py rename to tests/mocks/mock_driver_config.py index 0b374449d..c7407b8bc 100644 --- a/tests/mocks/mock_structure_config.py +++ b/tests/mocks/mock_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -9,20 +9,18 @@ @define -class MockStructureConfig(StructureConfig): - prompt_driver: MockPromptDriver = field( - default=Factory(lambda: MockPromptDriver()), metadata={"serializable": True} - ) - image_generation_driver: MockImageGenerationDriver = field( +class MockDriverConfig(DriverConfig): + prompt: MockPromptDriver = field(default=Factory(lambda: MockPromptDriver()), metadata={"serializable": True}) + image_generation: MockImageGenerationDriver = field( default=Factory(lambda: MockImageGenerationDriver(model="dall-e-2")), metadata={"serializable": True} ) - image_query_driver: MockImageQueryDriver = field( + image_query: MockImageQueryDriver = field( default=Factory(lambda: MockImageQueryDriver(model="gpt-4-vision-preview")), metadata={"serializable": True} ) - embedding_driver: MockEmbeddingDriver = field( + embedding: MockEmbeddingDriver = field( default=Factory(lambda: MockEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} ) - vector_store_driver: LocalVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: LocalVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), metadata={"serializable": True}, ) diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_driver_config.py similarity index 71% rename from tests/unit/config/test_amazon_bedrock_structure_config.py rename to tests/unit/config/test_amazon_bedrock_driver_config.py index afe9b3720..4fdbfedbc 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_driver_config.py @@ -1,7 +1,7 @@ import boto3 import pytest -from griptape.config import AmazonBedrockStructureConfig +from griptape.config import AmazonBedrockDriverConfig from tests.utils.aws import mock_aws_credentials @@ -13,11 +13,11 @@ def _run_before_and_after_tests(self): @pytest.fixture() def config(self): mock_aws_credentials() - return AmazonBedrockStructureConfig() + return AmazonBedrockDriverConfig() @pytest.fixture() def config_with_values(self): - return AmazonBedrockStructureConfig( + return AmazonBedrockDriverConfig( session=boto3.Session( aws_access_key_id="testing", aws_secret_access_key="testing", region_name="region-value" ) @@ -25,9 +25,9 @@ def config_with_values(self): def test_to_dict(self, config): assert config.to_dict() == { - "conversation_memory_driver": None, - "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, - "image_generation_driver": { + "conversation_memory": None, + "embedding": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation": { "image_generation_model_driver": { "cfg_scale": 7, "outpainting_mode": "PRECISE", @@ -40,13 +40,13 @@ def test_to_dict(self, config): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "type": "AmazonBedrockImageQueryDriver", "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, - "prompt_driver": { + "prompt": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, @@ -55,32 +55,31 @@ def test_to_dict(self, config): "tool_choice": {"auto": {}}, "use_native_tools": True, }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", }, "type": "LocalVectorStoreDriver", }, - "type": "AmazonBedrockStructureConfig", - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "type": "AmazonBedrockDriverConfig", + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert AmazonBedrockStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert AmazonBedrockDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() def test_from_dict_with_values(self, config_with_values): assert ( - AmazonBedrockStructureConfig.from_dict(config_with_values.to_dict()).to_dict() - == config_with_values.to_dict() + AmazonBedrockDriverConfig.from_dict(config_with_values.to_dict()).to_dict() == config_with_values.to_dict() ) def test_to_dict_with_values(self, config_with_values): assert config_with_values.to_dict() == { - "conversation_memory_driver": None, - "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, - "image_generation_driver": { + "conversation_memory": None, + "embedding": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation": { "image_generation_model_driver": { "cfg_scale": 7, "outpainting_mode": "PRECISE", @@ -93,13 +92,13 @@ def test_to_dict_with_values(self, config_with_values): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "type": "AmazonBedrockImageQueryDriver", "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, - "prompt_driver": { + "prompt": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, @@ -108,15 +107,15 @@ def test_to_dict_with_values(self, config_with_values): "tool_choice": {"auto": {}}, "use_native_tools": True, }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", }, "type": "LocalVectorStoreDriver", }, - "type": "AmazonBedrockStructureConfig", - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "type": "AmazonBedrockDriverConfig", + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } assert config_with_values.session.region_name == "region-value" diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_driver_config.py similarity index 65% rename from tests/unit/config/test_anthropic_structure_config.py rename to tests/unit/config/test_anthropic_driver_config.py index 05519fa5e..654e7ddf3 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_driver_config.py @@ -1,6 +1,6 @@ import pytest -from griptape.config import AnthropicStructureConfig +from griptape.config import AnthropicDriverConfig class TestAnthropicStructureConfig: @@ -11,12 +11,12 @@ def _mock_anthropic(self, mocker): @pytest.fixture() def config(self): - return AnthropicStructureConfig() + return AnthropicDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "AnthropicStructureConfig", - "prompt_driver": { + "type": "AnthropicDriverConfig", + "prompt": { "type": "AnthropicPromptDriver", "temperature": 0.1, "max_tokens": 1000, @@ -26,18 +26,18 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": { + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": { "type": "AnthropicImageQueryDriver", "model": "claude-3-5-sonnet-20240620", "max_tokens": 256, }, - "embedding_driver": { + "embedding": { "type": "VoyageAiEmbeddingDriver", "model": "voyage-large-2", "input_type": "document", }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "VoyageAiEmbeddingDriver", @@ -45,10 +45,10 @@ def test_to_dict(self, config): "input_type": "document", }, }, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert AnthropicStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert AnthropicDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_driver_config.py similarity index 84% rename from tests/unit/config/test_azure_openai_structure_config.py rename to tests/unit/config/test_azure_openai_driver_config.py index 810cb41a1..5c43f3522 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_driver_config.py @@ -1,6 +1,6 @@ import pytest -from griptape.config import AzureOpenAiStructureConfig +from griptape.config import AzureOpenAiDriverConfig class TestAzureOpenAiStructureConfig: @@ -10,7 +10,7 @@ def mock_openai(self, mocker): @pytest.fixture() def config(self): - return AzureOpenAiStructureConfig( + return AzureOpenAiDriverConfig( azure_endpoint="http://localhost:8080", azure_ad_token="test-token", azure_ad_token_provider=lambda: "test-provider", @@ -18,9 +18,9 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { - "type": "AzureOpenAiStructureConfig", + "type": "AzureOpenAiDriverConfig", "azure_endpoint": "http://localhost:8080", - "prompt_driver": { + "prompt": { "type": "AzureOpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", @@ -36,8 +36,8 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, - "embedding_driver": { + "conversation_memory": None, + "embedding": { "base_url": None, "model": "text-embedding-3-small", "api_version": "2023-05-15", @@ -46,7 +46,7 @@ def test_to_dict(self, config): "organization": None, "type": "AzureOpenAiEmbeddingDriver", }, - "image_generation_driver": { + "image_generation": { "api_version": "2024-02-01", "base_url": None, "image_size": "512x512", @@ -59,7 +59,7 @@ def test_to_dict(self, config): "style": None, "type": "AzureOpenAiImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "base_url": None, "image_quality": "auto", "max_tokens": 256, @@ -70,7 +70,7 @@ def test_to_dict(self, config): "organization": None, "type": "AzureOpenAiImageQueryDriver", }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", @@ -82,6 +82,6 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/config/test_cohere_structure_config.py b/tests/unit/config/test_cohere_driver_config.py similarity index 59% rename from tests/unit/config/test_cohere_structure_config.py rename to tests/unit/config/test_cohere_driver_config.py index 113a589ec..c056cabeb 100644 --- a/tests/unit/config/test_cohere_structure_config.py +++ b/tests/unit/config/test_cohere_driver_config.py @@ -1,22 +1,22 @@ import pytest -from griptape.config import CohereStructureConfig +from griptape.config import CohereDriverConfig class TestCohereStructureConfig: @pytest.fixture() def config(self): - return CohereStructureConfig(api_key="api_key") + return CohereDriverConfig(api_key="api_key") def test_to_dict(self, config): assert config.to_dict() == { - "type": "CohereStructureConfig", - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, - "prompt_driver": { + "type": "CohereDriverConfig", + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "prompt": { "type": "CoherePromptDriver", "temperature": 0.1, "max_tokens": None, @@ -25,12 +25,12 @@ def test_to_dict(self, config): "force_single_step": False, "use_native_tools": True, }, - "embedding_driver": { + "embedding": { "type": "CohereEmbeddingDriver", "model": "embed-english-v3.0", "input_type": "search_document", }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "CohereEmbeddingDriver", diff --git a/tests/unit/config/test_driver_config.py b/tests/unit/config/test_driver_config.py new file mode 100644 index 000000000..e5585de24 --- /dev/null +++ b/tests/unit/config/test_driver_config.py @@ -0,0 +1,39 @@ +import pytest + +from griptape.config import DriverConfig + + +class TestStructureConfig: + @pytest.fixture() + def config(self): + return DriverConfig() + + def test_to_dict(self, config): + assert config.to_dict() == { + "type": "DriverConfig", + "prompt": { + "type": "DummyPromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "use_native_tools": False, + }, + "conversation_memory": None, + "embedding": {"type": "DummyEmbeddingDriver"}, + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "vector_store": { + "embedding_driver": {"type": "DummyEmbeddingDriver"}, + "type": "DummyVectorStoreDriver", + }, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + } + + def test_from_dict(self, config): + assert DriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + + def test_dot_update(self, config): + config.prompt.max_tokens = 10 + + assert config.prompt.max_tokens == 10 diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_driver_config.py similarity index 63% rename from tests/unit/config/test_google_structure_config.py rename to tests/unit/config/test_google_driver_config.py index e193cc983..53663caf0 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_driver_config.py @@ -1,6 +1,6 @@ import pytest -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig class TestGoogleStructureConfig: @@ -10,12 +10,12 @@ def mock_openai(self, mocker): @pytest.fixture() def config(self): - return GoogleStructureConfig() + return GoogleDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "GoogleStructureConfig", - "prompt_driver": { + "type": "GoogleDriverConfig", + "prompt": { "type": "GooglePromptDriver", "temperature": 0.1, "max_tokens": None, @@ -26,15 +26,15 @@ def test_to_dict(self, config): "tool_choice": "auto", "use_native_tools": True, }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "embedding_driver": { + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "embedding": { "type": "GoogleEmbeddingDriver", "model": "models/embedding-001", "task_type": "retrieval_document", "title": None, }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "GoogleEmbeddingDriver", @@ -43,10 +43,10 @@ def test_to_dict(self, config): "title": None, }, }, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert GoogleStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert GoogleDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_driver_config.py similarity index 81% rename from tests/unit/config/test_openai_structure_config.py rename to tests/unit/config/test_openai_driver_config.py index 8969e0ad0..7af0b755a 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_driver_config.py @@ -1,6 +1,6 @@ import pytest -from griptape.config import OpenAiStructureConfig +from griptape.config import OpenAiDriverConfig class TestOpenAiStructureConfig: @@ -10,12 +10,12 @@ def mock_openai(self, mocker): @pytest.fixture() def config(self): - return OpenAiStructureConfig() + return OpenAiDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "OpenAiStructureConfig", - "prompt_driver": { + "type": "OpenAiDriverConfig", + "prompt": { "type": "OpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", @@ -28,14 +28,14 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, - "embedding_driver": { + "conversation_memory": None, + "embedding": { "base_url": None, "model": "text-embedding-3-small", "organization": None, "type": "OpenAiEmbeddingDriver", }, - "image_generation_driver": { + "image_generation": { "api_version": None, "base_url": None, "image_size": "512x512", @@ -46,7 +46,7 @@ def test_to_dict(self, config): "style": None, "type": "OpenAiImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "api_version": None, "base_url": None, "image_quality": "auto", @@ -55,7 +55,7 @@ def test_to_dict(self, config): "organization": None, "type": "OpenAiImageQueryDriver", }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", @@ -64,7 +64,7 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "text_to_speech_driver": { + "text_to_speech": { "type": "OpenAiTextToSpeechDriver", "api_version": None, "base_url": None, @@ -73,7 +73,7 @@ def test_to_dict(self, config): "organization": None, "voice": "alloy", }, - "audio_transcription_driver": { + "audio_transcription": { "type": "OpenAiAudioTranscriptionDriver", "api_version": None, "base_url": None, @@ -83,4 +83,4 @@ def test_to_dict(self, config): } def test_from_dict(self, config): - assert OpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert OpenAiDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py deleted file mode 100644 index cce97647e..000000000 --- a/tests/unit/config/test_structure_config.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest - -from griptape.config import StructureConfig - - -class TestStructureConfig: - @pytest.fixture() - def config(self): - return StructureConfig() - - def test_to_dict(self, config): - assert config.to_dict() == { - "type": "StructureConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "use_native_tools": False, - }, - "conversation_memory_driver": None, - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "vector_store_driver": { - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "type": "DummyVectorStoreDriver", - }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, - } - - def test_from_dict(self, config): - assert StructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - - def test_dot_update(self, config): - config.prompt_driver.max_tokens = 10 - - assert config.prompt_driver.max_tokens == 10 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index e49de0021..9207bbc1c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -2,7 +2,7 @@ from griptape.config import Config from griptape.events import EventBus -from tests.mocks.mock_structure_config import MockStructureConfig +from tests.mocks.mock_driver_config import MockDriverConfig @pytest.fixture(autouse=True) @@ -16,6 +16,6 @@ def event_bus(): @pytest.fixture(autouse=True) def mock_config(): - Config.drivers = MockStructureConfig() + Config.drivers = MockDriverConfig() return Config diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 84fd0bed1..248c259e5 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -11,7 +11,7 @@ class TestBasePromptDriver: def test_run_via_pipeline_retries_success(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=2) + mock_config.drivers.prompt = MockPromptDriver(max_attempts=2) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -19,7 +19,7 @@ def test_run_via_pipeline_retries_success(self, mock_config): assert isinstance(pipeline.run().output_task.output, TextArtifact) def test_run_via_pipeline_retries_failure(self, mock_config): - mock_config.drivers.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) + mock_config.drivers.prompt = MockFailingPromptDriver(max_failures=2, max_attempts=1) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -47,7 +47,7 @@ def test_run_with_stream(self): assert result.value == "mock output" def test_run_with_tools(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) + mock_config.drivers.prompt = MockPromptDriver(max_attempts=1, use_native_tools=True) pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index b2e9c069b..c2bb45208 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -20,7 +20,7 @@ def test_run(self): def test_run_with_env(self, mock_config): pipeline = Pipeline() - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + mock_config.drivers.prompt = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index d2681877f..038cb4508 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -26,7 +26,7 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(stream=True) + mock_config.drivers.prompt = MockPromptDriver(stream=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) pipeline = Pipeline() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 06e54e6c4..f0e4b0af3 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -97,7 +97,7 @@ def test_add_to_prompt_stack_autopruing_disabled(self): def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -119,9 +119,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): assert len(prompt_stack.messages) == 3 # No memory is pruned. - mock_config.drivers.prompt_driver = MockPromptDriver( - tokenizer=MockTokenizer(model="foo", max_input_tokens=1000) - ) + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000)) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -145,7 +143,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) agent = Agent() memory = ConversationMemory( autoprune=True, diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index 3eef4eec3..8f9278c3c 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -13,9 +13,7 @@ def task(self): return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) def test_run(self, task, mock_config): - mock_config.drivers.prompt_driver.mock_output = ( - '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - ) + mock_config.drivers.prompt.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' agent = Agent() agent.add_task(task) diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 2c0dc1b28..d18d75d75 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -6,9 +6,9 @@ class TestStructureRunTask: def test_run(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="agent mock output") + mock_config.drivers.prompt = MockPromptDriver(mock_output="agent mock output") agent = Agent() - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") + mock_config.drivers.prompt = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 70ab05e12..18521632e 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -168,9 +168,7 @@ class TestToolTask: def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - mock_config.drivers.prompt_driver = MockPromptDriver( - mock_output=f"```python foo bar\n{json.dumps(output_dict)}" - ) + mock_config.drivers.prompt = MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}") return Agent() diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 15f5a59b1..c1b91b1ed 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -171,7 +171,7 @@ def test_init(self): def test_run(self, mock_config): output = """Answer: done""" - mock_config.drivers.prompt_driver.mock_output = output + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent() @@ -186,7 +186,7 @@ def test_run(self, mock_config): def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' - mock_config.drivers.prompt_driver.mock_output = output + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() @@ -200,7 +200,7 @@ def test_run_max_subtasks(self, mock_config): def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" - mock_config.drivers.prompt_driver.mock_output = output + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 48dbaae29..318f434c3 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -10,11 +10,11 @@ class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - Config.drivers.prompt_driver.stream = request.param + Config.drivers.prompt.stream = request.param return Agent() def test_init(self, agent): - if Config.drivers.prompt_driver.stream: + if Config.drivers.prompt.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 5b908065b..9fa5e559a 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -25,9 +25,7 @@ def get_enabled_prompt_drivers(prompt_drivers_options) -> list[BasePromptDriver]: return [ - prompt_driver_option.prompt_driver - for prompt_driver_option in prompt_drivers_options - if prompt_driver_option.enabled + prompt_driver_option.prompt for prompt_driver_option in prompt_drivers_options if prompt_driver_option.enabled ] From ef24d49211d9afd73da7111c3745dd3e551913fb Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:37:43 -0700 Subject: [PATCH 07/40] Revert changelog update --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0720ca47..ea88983f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -263,7 +263,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. -- **BREAKING**: Removed `StructureConfig.drivers.global_drivers`. Pass Drivers directly to the Structure Config instead. +- **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. - **BREAKING**: Removed `StructureConfig.task_memory` in favor of configuring directly on the Structure. - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. @@ -391,7 +391,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Deprecation warnings not displaying for `Structure.prompt_driver`, `Structure.embedding_driver`, and `Structure.stream`. - `DummyException` error message not fully displaying. -- `StructureConfig.task_memory` not defaulting to using `StructureConfig.drivers.global_drivers` by default. +- `StructureConfig.task_memory` not defaulting to using `StructureConfig.global_drivers` by default. ## [0.23.1] - 2024-03-07 From ab6578315abc8cf7607f770b83e9ed93b344e076 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:39:51 -0700 Subject: [PATCH 08/40] Rename Structure Config to Driver Config --- CHANGELOG.md | 32 +++++++++---------- docs/griptape-framework/structures/config.md | 22 ++++++------- griptape/config/azure_openai_driver_config.py | 2 +- .../test_amazon_bedrock_driver_config.py | 2 +- .../config/test_anthropic_driver_config.py | 2 +- .../config/test_azure_openai_driver_config.py | 2 +- .../unit/config/test_cohere_driver_config.py | 2 +- tests/unit/config/test_driver_config.py | 2 +- .../unit/config/test_google_driver_config.py | 2 +- .../unit/config/test_openai_driver_config.py | 2 +- 10 files changed, 35 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea88983f3..785a5d083 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -112,7 +112,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `GoogleWebSearchDriver` to web search with the Google Customsearch API. - `DuckDuckGoWebSearchDriver` to web search with the DuckDuckGo search SDK. - `ProxyWebScraperDriver` to web scrape using proxies. -- Parameter `session` on `AmazonBedrockStructureConfig`. +- Parameter `session` on `AmazonBedrockDriverConfig`. - Parameter `meta` on `TextArtifact`. - `VectorStoreClient` improvements: - `VectorStoreClient.query_params` dict for custom query params. @@ -155,7 +155,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: removed `VectorStoreClient.top_n` and `VectorStoreClient.namespace` in favor of `VectorStoreClient.query_params`. - **BREAKING**: All `futures_executor` fields renamed to `futures_executor_fn` and now accept callables instead of futures; wrapped all future `submit` calls with the `with` block to address future executor shutdown issues. - `GriptapeCloudKnowledgeBaseClient` migrated to `/search` api. -- Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`. +- Default Prompt Driver model in `GoogleDriverConfig` to `gemini-1.5-pro`. ### Fixed - `CoherePromptDriver` to properly handle empty history. @@ -175,7 +175,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Base Tool schema so that `input` is optional when no Tool Activity schema is set. - Tool Task system prompt for better results with lower-end models. -- Default Prompt Driver model to Claude 3.5 Sonnet in `AnthropicStructureConfig` and `AmazonBedrockStructureConfig.` +- Default Prompt Driver model to Claude 3.5 Sonnet in `AnthropicDriverConfig` and `AmazonBedrockDriverConfig.` ## [0.27.0] - 2024-06-19 @@ -186,7 +186,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseTask.add_parents()` to add multiple parent tasks to a child task. - `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`. - `CohereEmbeddingDriver` for using Cohere's embeddings API. -- `CohereStructureConfig` for providing Structures with quick Cohere configuration. +- `CohereDriverConfig` for providing Structures with quick Cohere configuration. - `AmazonSageMakerJumpstartPromptDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. - `AmazonSageMakerJumpstartEmbeddingDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. - `AmazonSageMakerJumpstartEmbeddingDriver.custom_attributes` for setting custom attributes when invoking an endpoint. @@ -252,7 +252,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.26.0] - 2024-06-04 ### Added -- `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration. +- `AzureOpenAiDriverConfig` for providing Structures with all Azure OpenAI Driver configuration. - `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models. - `AudioLoader` for loading audio content into an `AudioArtifact`. - `AudioTranscriptionTask` and `AudioTranscriptionClient` for transcribing audio content in Structures. @@ -263,8 +263,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. -- **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. -- **BREAKING**: Removed `StructureConfig.task_memory` in favor of configuring directly on the Structure. +- **BREAKING**: Removed `DriverConfig.global_drivers`. Pass Drivers directly to the Driver Config instead. +- **BREAKING**: Removed `DriverConfig.task_memory` in favor of configuring directly on the Structure. - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. - **BREAKING**: `AmazonSageMakerPromptDriver.model` parameter, which gets passed to `SageMakerRuntime.Client.invoke_endpoint` as `EndpointName`, is now renamed to `AmazonSageMakerPromptDriver.endpoint`. @@ -293,7 +293,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Default behavior of Event Listener Drivers to batch events. -- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver. +- Default behavior of OpenAiDriverConfig to utilize `gpt-4o` for prompt_driver. ## [0.25.0] - 2024-05-06 @@ -359,7 +359,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `text-embedding-3-small` and `text-embedding-3-large` models. - `GooglePromptDriver` and `GoogleTokenizer` for use with `gemini-pro`. - `GoogleEmbeddingDriver` for use with `embedding-001`. -- `GoogleStructureConfig` for providing Structures with Google Prompt and Embedding Driver configuration. +- `GoogleDriverConfig` for providing Structures with Google Prompt and Embedding Driver configuration. - Support for `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku` in `AnthropicPromptDriver`. - Support for `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0` in `BedrockClaudePromptModelDriver`. - `top_k` and `top_p` parameters in `AnthropicPromptDriver`. @@ -369,7 +369,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `TrafilaturaWebScraperDriver` for scraping text from web pages using trafilatura. - `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify. - `VoyageAiEmbeddingDriver` for use with VoyageAi's embedding models. -- `AnthropicStructureConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. +- `AnthropicDriverConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. - `QdrantVectorStoreDriver` to integrate with Qdrant vector databases. ### Fixed @@ -380,9 +380,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `subtask_action_name`, `subtask_action_path`, and `subtask_action_input` in `BaseActionSubtaskEvent`. - **BREAKING**: `OpenAiVisionImageQueryDriver` field `model` no longer defaults to `gpt-4-vision-preview` and must be specified - Default model of `OpenAiEmbeddingDriver` to `text-embedding-3-small`. -- Default model of `OpenAiStructureConfig` to `text-embedding-3-small`. +- Default model of `OpenAiDriverConfig` to `text-embedding-3-small`. - `BaseTextLoader` to accept a `BaseChunker`. -- Default model of `AmazonBedrockStructureConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. +- Default model of `AmazonBedrockDriverConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. - `AnthropicPromptDriver` and `BedrockClaudePromptModelDriver` to use Anthropic's Messages API. - `OpenAiVisionImageQueryDriver` now has a required field `max_tokens` that defaults to 256 @@ -391,7 +391,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Deprecation warnings not displaying for `Structure.prompt_driver`, `Structure.embedding_driver`, and `Structure.stream`. - `DummyException` error message not fully displaying. -- `StructureConfig.task_memory` not defaulting to using `StructureConfig.global_drivers` by default. +- `DriverConfig.task_memory` not defaulting to using `DriverConfig.global_drivers` by default. ## [0.23.1] - 2024-03-07 @@ -408,9 +408,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AzureMongoDbVectorStoreDriver` for using CosmosDB with MongoDB vCore API. - `vector_path` field on `MongoDbAtlasVectorStoreDriver`. - `LeonardoImageGenerationDriver` supports image to image generation. -- `OpenAiStructureConfig` for providing Structures with all OpenAi Driver configuration. -- `AmazonBedrockStructureConfig` for providing Structures with all Amazon Bedrock Driver configuration. -- `StructureConfig` for building your own Structure configuration. +- `OpenAiDriverConfig` for providing Structures with all OpenAi Driver configuration. +- `AmazonBedrockDriverConfig` for providing Structures with all Amazon Bedrock Driver configuration. +- `DriverConfig` for building your own Structure configuration. - `JsonExtractionTask` for convenience over using `ExtractionTask` with a `JsonExtractionEngine`. - `CsvExtractionTask` for convenience over using `ExtractionTask` with a `CsvExtractionEngine`. - `OpenAiVisionImageQueryDriver` to support queries on images using OpenAI's Vision model. diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 17fb9e5da..917485837 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -5,15 +5,15 @@ search: ## Overview -The [StructureConfig](../../reference/griptape/config/driver_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. +The [DriverConfig](../../reference/griptape/config/driver_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. ### Premade Configs -Griptape provides predefined [StructureConfig](../../reference/griptape/config/driver_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. +Griptape provides predefined [DriverConfig](../../reference/griptape/config/driver_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. #### OpenAI -The [OpenAI Structure Config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. +The [OpenAI Driver Config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. ```python from griptape.structures import Agent @@ -28,7 +28,7 @@ agent = Agent() # This is equivalent to the above #### Azure OpenAI -The [Azure OpenAI Structure Config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. +The [Azure OpenAI Driver Config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. ```python import os @@ -48,7 +48,7 @@ agent = Agent( ``` #### Amazon Bedrock -The [Amazon Bedrock Structure Config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. +The [Amazon Bedrock Driver Config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. ```python import os @@ -68,7 +68,7 @@ agent = Agent( ``` #### Google -The [Google Structure Config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. +The [Google Driver Config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. ```python from griptape.structures import Agent @@ -81,11 +81,11 @@ agent = Agent( #### Anthropic -The [Anthropic Structure Config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. +The [Anthropic Driver Config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. !!! info Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. - The `AnthropicStructureConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). + The `AnthropicDriverConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). To override the default embedding driver, see: [Override Default Structure Embedding Driver](../drivers/embedding-drivers.md#override-default-structure-embedding-driver). ```python @@ -99,7 +99,7 @@ agent = Agent( #### Cohere -The [Cohere Structure Config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. +The [Cohere Driver Config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. ```python import os @@ -111,8 +111,8 @@ agent = Agent(config=CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])) ### Custom Configs -You can create your own [StructureConfig](../../reference/griptape/config/driver_config.md) by overriding relevant Drivers. -The [StructureConfig](../../reference/griptape/config/driver_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. +You can create your own [DriverConfig](../../reference/griptape/config/driver_config.md) by overriding relevant Drivers. +The [DriverConfig](../../reference/griptape/config/driver_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations. ```python diff --git a/griptape/config/azure_openai_driver_config.py b/griptape/config/azure_openai_driver_config.py index ef965fa28..c987a31b5 100644 --- a/griptape/config/azure_openai_driver_config.py +++ b/griptape/config/azure_openai_driver_config.py @@ -21,7 +21,7 @@ @define class AzureOpenAiDriverConfig(DriverConfig): - """Azure OpenAI Structure Configuration. + """Azure OpenAI Driver Configuration. Attributes: azure_endpoint: The endpoint for the Azure OpenAI instance. diff --git a/tests/unit/config/test_amazon_bedrock_driver_config.py b/tests/unit/config/test_amazon_bedrock_driver_config.py index 4fdbfedbc..57a80809e 100644 --- a/tests/unit/config/test_amazon_bedrock_driver_config.py +++ b/tests/unit/config/test_amazon_bedrock_driver_config.py @@ -5,7 +5,7 @@ from tests.utils.aws import mock_aws_credentials -class TestAmazonBedrockStructureConfig: +class TestAmazonBedrockDriverConfig: @pytest.fixture(autouse=True) def _run_before_and_after_tests(self): mock_aws_credentials() diff --git a/tests/unit/config/test_anthropic_driver_config.py b/tests/unit/config/test_anthropic_driver_config.py index 654e7ddf3..a2ccbd25b 100644 --- a/tests/unit/config/test_anthropic_driver_config.py +++ b/tests/unit/config/test_anthropic_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import AnthropicDriverConfig -class TestAnthropicStructureConfig: +class TestAnthropicDriverConfig: @pytest.fixture(autouse=True) def _mock_anthropic(self, mocker): mocker.patch("anthropic.Anthropic") diff --git a/tests/unit/config/test_azure_openai_driver_config.py b/tests/unit/config/test_azure_openai_driver_config.py index 5c43f3522..3c88b859d 100644 --- a/tests/unit/config/test_azure_openai_driver_config.py +++ b/tests/unit/config/test_azure_openai_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import AzureOpenAiDriverConfig -class TestAzureOpenAiStructureConfig: +class TestAzureOpenAiDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") diff --git a/tests/unit/config/test_cohere_driver_config.py b/tests/unit/config/test_cohere_driver_config.py index c056cabeb..9e8407d84 100644 --- a/tests/unit/config/test_cohere_driver_config.py +++ b/tests/unit/config/test_cohere_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import CohereDriverConfig -class TestCohereStructureConfig: +class TestCohereDriverConfig: @pytest.fixture() def config(self): return CohereDriverConfig(api_key="api_key") diff --git a/tests/unit/config/test_driver_config.py b/tests/unit/config/test_driver_config.py index e5585de24..dd3fd1a47 100644 --- a/tests/unit/config/test_driver_config.py +++ b/tests/unit/config/test_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import DriverConfig -class TestStructureConfig: +class TestDriverConfig: @pytest.fixture() def config(self): return DriverConfig() diff --git a/tests/unit/config/test_google_driver_config.py b/tests/unit/config/test_google_driver_config.py index 53663caf0..fb6cd23b5 100644 --- a/tests/unit/config/test_google_driver_config.py +++ b/tests/unit/config/test_google_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import GoogleDriverConfig -class TestGoogleStructureConfig: +class TestGoogleDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("google.generativeai.GenerativeModel") diff --git a/tests/unit/config/test_openai_driver_config.py b/tests/unit/config/test_openai_driver_config.py index 7af0b755a..55156730c 100644 --- a/tests/unit/config/test_openai_driver_config.py +++ b/tests/unit/config/test_openai_driver_config.py @@ -3,7 +3,7 @@ from griptape.config import OpenAiDriverConfig -class TestOpenAiStructureConfig: +class TestOpenAiDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.OpenAI") From ecfa3583c0728c6f8822eb1269c37b7bc6596072 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 12:49:45 -0700 Subject: [PATCH 09/40] Rename doc fields --- docs/examples/multiple-agent-shared-memory.md | 4 +-- .../drivers/embedding-drivers.md | 4 +-- .../drivers/event-listener-drivers.md | 2 +- .../drivers/prompt-drivers.md | 26 +++++++++---------- docs/griptape-framework/misc/events.md | 2 +- docs/griptape-framework/structures/config.md | 2 +- .../official-tools/rest-api-client.md | 2 +- tests/utils/structure_tester.py | 16 +++++++----- 8 files changed, 30 insertions(+), 28 deletions(-) diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index e6b092965..0fe589d7b 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -42,8 +42,8 @@ mongo_driver = AzureMongoDbVectorStoreDriver( config = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, - vector_store_driver=mongo_driver, - embedding_driver=embedding_driver, + vector_store=mongo_driver, + embedding=embedding_driver, ) loader = Agent( diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index de2f2d379..3c81cf8a9 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -225,8 +225,8 @@ from griptape.config import DriverConfig agent = Agent( tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), - embedding_driver=VoyageAiEmbeddingDriver(), + prompt=OpenAiChatPromptDriver(model="gpt-4o"), + embedding=VoyageAiEmbeddingDriver(), ), ) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 0adb0b10f..8d4f521aa 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -139,7 +139,7 @@ agent = Agent( ) ], config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver( + prompt=OpenAiChatPromptDriver( model="gpt-3.5-turbo", temperature=0.7 ) ), diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 8693cc6ff..367bff4ba 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -17,7 +17,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), + prompt=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), ), input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", rules=[ @@ -75,7 +75,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver( + prompt=OpenAiChatPromptDriver( api_key=os.environ["OPENAI_API_KEY"], temperature=0.1, model="gpt-4o", @@ -110,7 +110,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver( + prompt=OpenAiChatPromptDriver( base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True ) @@ -138,7 +138,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=AzureOpenAiChatPromptDriver( + prompt=AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-3.5-turbo", azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"], @@ -172,7 +172,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=CoherePromptDriver( + prompt=CoherePromptDriver( model="command-r", api_key=os.environ['COHERE_API_KEY'], ) @@ -198,7 +198,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=AnthropicPromptDriver( + prompt=AnthropicPromptDriver( model="claude-3-opus-20240229", api_key=os.environ['ANTHROPIC_API_KEY'], ) @@ -224,7 +224,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=GooglePromptDriver( + prompt=GooglePromptDriver( model="gemini-pro", api_key=os.environ['GOOGLE_API_KEY'], ) @@ -252,7 +252,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=AmazonBedrockPromptDriver( + prompt=AmazonBedrockPromptDriver( model="anthropic.claude-3-sonnet-20240229-v1:0", ) ), @@ -292,7 +292,7 @@ from griptape.structures import Agent agent = Agent( config=DriverConfig( - prompt_driver=OllamaPromptDriver( + prompt=OllamaPromptDriver( model="llama3.1", ), ), @@ -322,7 +322,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=HuggingFaceHubPromptDriver( + prompt=HuggingFaceHubPromptDriver( model="HuggingFaceH4/zephyr-7b-beta", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ) @@ -356,7 +356,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=HuggingFaceHubPromptDriver( + prompt=HuggingFaceHubPromptDriver( model="http://127.0.0.1:8080", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ), @@ -384,7 +384,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=HuggingFacePipelinePromptDriver( + prompt=HuggingFacePipelinePromptDriver( model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", ) ), @@ -425,7 +425,7 @@ from griptape.config import DriverConfig agent = Agent( config=DriverConfig( - prompt_driver=AmazonSageMakerJumpstartPromptDriver( + prompt=AmazonSageMakerJumpstartPromptDriver( endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], model="meta-llama/Meta-Llama-3-8B-Instruct", ) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index dfb6e2db3..b7f118d98 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -149,7 +149,7 @@ EventBus.event_listeners = [ pipeline = Pipeline( config=OpenAiDriverConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True) + prompt=OpenAiChatPromptDriver(model="gpt-4o", stream=True) ), ) diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 917485837..13b6c001a 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -123,7 +123,7 @@ from griptape.drivers import AnthropicPromptDriver agent = Agent( config=DriverConfig( - prompt_driver=AnthropicPromptDriver( + prompt=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], ) diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index 304ec00ec..675f77b6e 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -118,7 +118,7 @@ posts_client = RestApiClient( pipeline = Pipeline( conversation_memory=ConversationMemory(), config=DriverConfig( - prompt_driver=OpenAiChatPromptDriver( + prompt=OpenAiChatPromptDriver( model="gpt-4o", temperature=0.1 ), diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 9fa5e559a..2b9f83b81 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -226,6 +226,15 @@ def prompt_driver_id_fn(cls, prompt_driver) -> str: return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}" def verify_structure_output(self, structure) -> dict: + from griptape.config import Config + + Config.drivers.prompt = AzureOpenAiChatPromptDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + model="gpt-4o", + azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], + response_format="json_object", + ) output_schema = Schema( { Literal("correct", description="Whether the output was correct or not."): bool, @@ -263,13 +272,6 @@ def verify_structure_output(self, structure) -> dict: ], ), ], - prompt_driver=AzureOpenAiChatPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="gpt-4o", - azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - response_format="json_object", - ), tasks=[ PromptTask( "\nTasks: {{ task_names }}" From 582f55c372c1a08611c8dc73f8f1072637940050 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 13:18:34 -0700 Subject: [PATCH 10/40] Move events into config --- CHANGELOG.md | 4 ++-- griptape/config/base_config.py | 6 +++++- griptape/config/base_driver_config.py | 4 +++- griptape/config/config.py | 17 +++++++---------- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 785a5d083..1b224bdd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,8 +49,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseVectorStoreDriver.load_artifacts` optional arguments are now keyword-only arguments. - **BREAKING**: `BaseVectorStoreDriver.upsert_vector` optional arguments are now keyword-only arguments. - **BREAKING**: `BaseVectorStoreDriver.query` optional arguments are now keyword-only arguments. -- **BREAKING**: `EventListener.publish_event`'s `flush` argument is now a keyword-only argument. -- **BREAKING**: `BaseEventListenerDriver.publish_event`'s `flush` argument is now a keyword-only argument. +- **BREAKING**: `EventListener.events.publish_event`'s `flush` argument is now a keyword-only argument. +- **BREAKING**: `BaseEventListenerDriver.events.publish_event`'s `flush` argument is now a keyword-only argument. - **BREAKING**: Renamed `DummyException` to `DummyError` for pep8 naming compliance. - **BREAKING**: Migrate to `sqlalchemy` 2.0. - **BREAKING**: Make `sqlalchemy` an optional dependency. diff --git a/griptape/config/base_config.py b/griptape/config/base_config.py index 241efadcd..a3f132ea9 100644 --- a/griptape/config/base_config.py +++ b/griptape/config/base_config.py @@ -2,8 +2,12 @@ from attrs import define +from griptape.config.base_driver_config import BaseDriverConfig +from griptape.config.events_config import EventsConfig from griptape.mixins.serializable_mixin import SerializableMixin @define -class BaseConfig(SerializableMixin, ABC): ... +class BaseConfig(SerializableMixin, ABC): + drivers: BaseDriverConfig + events: EventsConfig diff --git a/griptape/config/base_driver_config.py b/griptape/config/base_driver_config.py index 46ff181d3..df32d382e 100644 --- a/griptape/config/base_driver_config.py +++ b/griptape/config/base_driver_config.py @@ -5,6 +5,8 @@ from attrs import define, field +from griptape.mixins import SerializableMixin + if TYPE_CHECKING: from griptape.drivers import ( BaseAudioTranscriptionDriver, @@ -19,7 +21,7 @@ @define -class BaseDriverConfig(ABC): +class BaseDriverConfig(ABC, SerializableMixin): prompt: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) image_generation: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) image_query: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/config.py b/griptape/config/config.py index f325d1265..71edef8e6 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,17 +1,14 @@ -from attrs import define - -from griptape.config.base_config import BaseConfig -from griptape.config.base_driver_config import BaseDriverConfig -from griptape.mixins.event_publisher_mixin import EventPublisherMixin +from attrs import Factory, define, field +from .base_config import BaseConfig +from .events_config import EventsConfig from .openai_driver_config import OpenAiDriverConfig @define -class _Config(BaseConfig, EventPublisherMixin): - drivers: BaseDriverConfig +class _Config(BaseConfig): + drivers: OpenAiDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) + events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True) -Config = _Config( - drivers=OpenAiDriverConfig(), -) +Config = _Config() From e80f78e3c0c95087e7a1dcec5917d387f19dcd7a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 13:22:14 -0700 Subject: [PATCH 11/40] Add back util fields for Agent --- griptape/structures/agent.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 31e0a424f..f31e9d2eb 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -2,16 +2,18 @@ from typing import TYPE_CHECKING, Callable, Optional -from attrs import Attribute, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable +from griptape.config import Config from griptape.memory.structure import Run from griptape.structures import Structure from griptape.tasks import PromptTask, ToolkitTask if TYPE_CHECKING: from griptape.artifacts import BaseArtifact + from griptape.drivers import BasePromptDriver from griptape.tasks import BaseTask from griptape.tools import BaseTool @@ -21,6 +23,8 @@ class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) + stream: bool = field(default=False, kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) @@ -33,11 +37,19 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB def __attrs_post_init__(self) -> None: super().__attrs_post_init__() + self.prompt_driver.stream = self.stream if len(self.tasks) == 0: if self.tools: - task = ToolkitTask(self.input, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries) + task = ToolkitTask( + self.input, + prompt_driver=self.prompt_driver, + tools=self.tools, + max_meta_memory_entries=self.max_meta_memory_entries, + ) else: - task = PromptTask(self.input, max_meta_memory_entries=self.max_meta_memory_entries) + task = PromptTask( + self.input, prompt_driver=self.prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries + ) self.add_task(task) From 91619a5b1031654bd596f66ac21ba495a366c095 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 13:23:44 -0700 Subject: [PATCH 12/40] Revert changelog replaces --- CHANGELOG.md | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b224bdd4..ea88983f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,8 +49,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseVectorStoreDriver.load_artifacts` optional arguments are now keyword-only arguments. - **BREAKING**: `BaseVectorStoreDriver.upsert_vector` optional arguments are now keyword-only arguments. - **BREAKING**: `BaseVectorStoreDriver.query` optional arguments are now keyword-only arguments. -- **BREAKING**: `EventListener.events.publish_event`'s `flush` argument is now a keyword-only argument. -- **BREAKING**: `BaseEventListenerDriver.events.publish_event`'s `flush` argument is now a keyword-only argument. +- **BREAKING**: `EventListener.publish_event`'s `flush` argument is now a keyword-only argument. +- **BREAKING**: `BaseEventListenerDriver.publish_event`'s `flush` argument is now a keyword-only argument. - **BREAKING**: Renamed `DummyException` to `DummyError` for pep8 naming compliance. - **BREAKING**: Migrate to `sqlalchemy` 2.0. - **BREAKING**: Make `sqlalchemy` an optional dependency. @@ -112,7 +112,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `GoogleWebSearchDriver` to web search with the Google Customsearch API. - `DuckDuckGoWebSearchDriver` to web search with the DuckDuckGo search SDK. - `ProxyWebScraperDriver` to web scrape using proxies. -- Parameter `session` on `AmazonBedrockDriverConfig`. +- Parameter `session` on `AmazonBedrockStructureConfig`. - Parameter `meta` on `TextArtifact`. - `VectorStoreClient` improvements: - `VectorStoreClient.query_params` dict for custom query params. @@ -155,7 +155,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: removed `VectorStoreClient.top_n` and `VectorStoreClient.namespace` in favor of `VectorStoreClient.query_params`. - **BREAKING**: All `futures_executor` fields renamed to `futures_executor_fn` and now accept callables instead of futures; wrapped all future `submit` calls with the `with` block to address future executor shutdown issues. - `GriptapeCloudKnowledgeBaseClient` migrated to `/search` api. -- Default Prompt Driver model in `GoogleDriverConfig` to `gemini-1.5-pro`. +- Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`. ### Fixed - `CoherePromptDriver` to properly handle empty history. @@ -175,7 +175,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Base Tool schema so that `input` is optional when no Tool Activity schema is set. - Tool Task system prompt for better results with lower-end models. -- Default Prompt Driver model to Claude 3.5 Sonnet in `AnthropicDriverConfig` and `AmazonBedrockDriverConfig.` +- Default Prompt Driver model to Claude 3.5 Sonnet in `AnthropicStructureConfig` and `AmazonBedrockStructureConfig.` ## [0.27.0] - 2024-06-19 @@ -186,7 +186,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BaseTask.add_parents()` to add multiple parent tasks to a child task. - `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`. - `CohereEmbeddingDriver` for using Cohere's embeddings API. -- `CohereDriverConfig` for providing Structures with quick Cohere configuration. +- `CohereStructureConfig` for providing Structures with quick Cohere configuration. - `AmazonSageMakerJumpstartPromptDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. - `AmazonSageMakerJumpstartEmbeddingDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. - `AmazonSageMakerJumpstartEmbeddingDriver.custom_attributes` for setting custom attributes when invoking an endpoint. @@ -252,7 +252,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.26.0] - 2024-06-04 ### Added -- `AzureOpenAiDriverConfig` for providing Structures with all Azure OpenAI Driver configuration. +- `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration. - `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models. - `AudioLoader` for loading audio content into an `AudioArtifact`. - `AudioTranscriptionTask` and `AudioTranscriptionClient` for transcribing audio content in Structures. @@ -263,8 +263,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. -- **BREAKING**: Removed `DriverConfig.global_drivers`. Pass Drivers directly to the Driver Config instead. -- **BREAKING**: Removed `DriverConfig.task_memory` in favor of configuring directly on the Structure. +- **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. +- **BREAKING**: Removed `StructureConfig.task_memory` in favor of configuring directly on the Structure. - **BREAKING**: Updated OpenAI-based image query drivers to remove Vision from the name. - **BREAKING**: `off_prompt` now defaults to `False` on all Tools, making Task Memory something that must be explicitly opted into. - **BREAKING**: `AmazonSageMakerPromptDriver.model` parameter, which gets passed to `SageMakerRuntime.Client.invoke_endpoint` as `EndpointName`, is now renamed to `AmazonSageMakerPromptDriver.endpoint`. @@ -293,7 +293,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Default behavior of Event Listener Drivers to batch events. -- Default behavior of OpenAiDriverConfig to utilize `gpt-4o` for prompt_driver. +- Default behavior of OpenAiStructureConfig to utilize `gpt-4o` for prompt_driver. ## [0.25.0] - 2024-05-06 @@ -359,7 +359,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `text-embedding-3-small` and `text-embedding-3-large` models. - `GooglePromptDriver` and `GoogleTokenizer` for use with `gemini-pro`. - `GoogleEmbeddingDriver` for use with `embedding-001`. -- `GoogleDriverConfig` for providing Structures with Google Prompt and Embedding Driver configuration. +- `GoogleStructureConfig` for providing Structures with Google Prompt and Embedding Driver configuration. - Support for `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku` in `AnthropicPromptDriver`. - Support for `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0` in `BedrockClaudePromptModelDriver`. - `top_k` and `top_p` parameters in `AnthropicPromptDriver`. @@ -369,7 +369,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `TrafilaturaWebScraperDriver` for scraping text from web pages using trafilatura. - `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify. - `VoyageAiEmbeddingDriver` for use with VoyageAi's embedding models. -- `AnthropicDriverConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. +- `AnthropicStructureConfig` for providing Structures with Anthropic Prompt and VoyageAi Embedding Driver configuration. - `QdrantVectorStoreDriver` to integrate with Qdrant vector databases. ### Fixed @@ -380,9 +380,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `subtask_action_name`, `subtask_action_path`, and `subtask_action_input` in `BaseActionSubtaskEvent`. - **BREAKING**: `OpenAiVisionImageQueryDriver` field `model` no longer defaults to `gpt-4-vision-preview` and must be specified - Default model of `OpenAiEmbeddingDriver` to `text-embedding-3-small`. -- Default model of `OpenAiDriverConfig` to `text-embedding-3-small`. +- Default model of `OpenAiStructureConfig` to `text-embedding-3-small`. - `BaseTextLoader` to accept a `BaseChunker`. -- Default model of `AmazonBedrockDriverConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. +- Default model of `AmazonBedrockStructureConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. - `AnthropicPromptDriver` and `BedrockClaudePromptModelDriver` to use Anthropic's Messages API. - `OpenAiVisionImageQueryDriver` now has a required field `max_tokens` that defaults to 256 @@ -391,7 +391,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Deprecation warnings not displaying for `Structure.prompt_driver`, `Structure.embedding_driver`, and `Structure.stream`. - `DummyException` error message not fully displaying. -- `DriverConfig.task_memory` not defaulting to using `DriverConfig.global_drivers` by default. +- `StructureConfig.task_memory` not defaulting to using `StructureConfig.global_drivers` by default. ## [0.23.1] - 2024-03-07 @@ -408,9 +408,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AzureMongoDbVectorStoreDriver` for using CosmosDB with MongoDB vCore API. - `vector_path` field on `MongoDbAtlasVectorStoreDriver`. - `LeonardoImageGenerationDriver` supports image to image generation. -- `OpenAiDriverConfig` for providing Structures with all OpenAi Driver configuration. -- `AmazonBedrockDriverConfig` for providing Structures with all Amazon Bedrock Driver configuration. -- `DriverConfig` for building your own Structure configuration. +- `OpenAiStructureConfig` for providing Structures with all OpenAi Driver configuration. +- `AmazonBedrockStructureConfig` for providing Structures with all Amazon Bedrock Driver configuration. +- `StructureConfig` for building your own Structure configuration. - `JsonExtractionTask` for convenience over using `ExtractionTask` with a `JsonExtractionEngine`. - `CsvExtractionTask` for convenience over using `ExtractionTask` with a `CsvExtractionEngine`. - `OpenAiVisionImageQueryDriver` to support queries on images using OpenAI's Vision model. From 9ebd44e51f321c69b1981b657f4bd433a038518c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 13:25:36 -0700 Subject: [PATCH 13/40] Fix type --- griptape/config/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/griptape/config/config.py b/griptape/config/config.py index 71edef8e6..3920bdbb0 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,13 +1,14 @@ from attrs import Factory, define, field from .base_config import BaseConfig +from .base_driver_config import BaseDriverConfig from .events_config import EventsConfig from .openai_driver_config import OpenAiDriverConfig @define class _Config(BaseConfig): - drivers: OpenAiDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) + drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True) From b7e1359d886a84ab8d47d80cae7380558d54e345 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 6 Aug 2024 14:58:53 -0700 Subject: [PATCH 14/40] Revert some of agent test --- tests/unit/structures/test_agent.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 15e1399b6..d82de015c 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -5,13 +5,16 @@ from griptape.rules import Rule, Ruleset from griptape.structures import Agent from griptape.tasks import BaseTask, PromptTask, ToolkitTask +from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool class TestAgent: def test_init(self): - agent = Agent(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + driver = MockPromptDriver() + agent = Agent(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + assert agent.prompt_driver is driver assert isinstance(agent.task, PromptTask) assert isinstance(agent.task, PromptTask) assert agent.rulesets[0].name == "TestRuleset" @@ -77,7 +80,7 @@ def test_without_default_task_memory(self): assert agent.tools[0].output_memory is None def test_with_memory(self): - agent = Agent(conversation_memory=ConversationMemory()) + agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) assert agent.conversation_memory is not None assert len(agent.conversation_memory.runs) == 0 @@ -99,7 +102,7 @@ def test_tasks_initialization(self): assert agent.tasks[0] == task def test_add_task(self): - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) assert len(agent.tasks) == 1 @@ -127,7 +130,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) try: agent.add_tasks(first_task, second_task) @@ -142,7 +145,7 @@ def test_add_tasks(self): assert True def test_prompt_stack_without_memory(self): - agent = Agent(conversation_memory=None, rules=[Rule("test")]) + agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=None, rules=[Rule("test")]) task1 = PromptTask("test") @@ -159,7 +162,7 @@ def test_prompt_stack_without_memory(self): assert len(task1.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): - agent = Agent(conversation_memory=ConversationMemory(), rules=[Rule("test")]) + agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory(), rules=[Rule("test")]) task1 = PromptTask("test") @@ -177,7 +180,7 @@ def test_prompt_stack_with_memory(self): def test_run(self): task = PromptTask("test") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) agent.add_task(task) assert task.state == BaseTask.State.PENDING @@ -189,7 +192,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) agent.add_task(task) agent._execution_args = ("test1", "test2") @@ -202,7 +205,7 @@ def test_run_with_args(self): def test_context(self): task = PromptTask("test prompt") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) agent.add_task(task) @@ -214,7 +217,7 @@ def test_context(self): def finished_tasks(self): task = PromptTask("test prompt") - agent = Agent() + agent = Agent(prompt_driver=MockPromptDriver()) agent.add_task(task) @@ -224,4 +227,4 @@ def finished_tasks(self): def test_fail_fast(self): with pytest.raises(ValueError): - Agent(fail_fast=True) + Agent(prompt_driver=MockPromptDriver(), fail_fast=True) From 97564a22d03d550373d2b02802e346979eaf1f5e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 11:42:27 -0700 Subject: [PATCH 15/40] Add logging module to config --- griptape/config/config.py | 2 ++ griptape/config/logging_config.py | 24 ++++++++++++++++++++ griptape/structures/structure.py | 22 ------------------ griptape/tasks/actions_subtask.py | 17 ++++++++------ griptape/tasks/base_audio_generation_task.py | 8 +++++-- griptape/tasks/base_audio_input_task.py | 8 +++++-- griptape/tasks/base_image_generation_task.py | 7 +++++- griptape/tasks/base_multi_text_input_task.py | 8 +++++-- griptape/tasks/base_task.py | 5 +++- griptape/tasks/base_text_input_task.py | 8 +++++-- griptape/tasks/prompt_task.py | 7 ++++-- 11 files changed, 75 insertions(+), 41 deletions(-) create mode 100644 griptape/config/logging_config.py diff --git a/griptape/config/config.py b/griptape/config/config.py index 3920bdbb0..8d29b5a0f 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -3,6 +3,7 @@ from .base_config import BaseConfig from .base_driver_config import BaseDriverConfig from .events_config import EventsConfig +from .logging_config import LoggingConfig from .openai_driver_config import OpenAiDriverConfig @@ -10,6 +11,7 @@ class _Config(BaseConfig): drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True) + logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True) Config = _Config() diff --git a/griptape/config/logging_config.py b/griptape/config/logging_config.py new file mode 100644 index 000000000..0c0fcc020 --- /dev/null +++ b/griptape/config/logging_config.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import logging + +from attrs import define, field +from rich.logging import RichHandler + + +@define +class LoggingConfig: + logger_name: str = field(default="griptape", kw_only=True) + logger_level: int = field( + default=logging.INFO, + kw_only=True, + on_setattr=lambda self, _, value: logging.getLogger(self.logger_name).setLevel(value), + ) + + def __attrs_post_init__(self) -> None: + logger = logging.getLogger(self.logger_name) + + logger.propagate = False + logger.setLevel(self.logger_level) + + logger.handlers = [RichHandler(show_time=True, show_path=False)] diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 4db8228e3..49197592f 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -1,13 +1,10 @@ from __future__ import annotations -import logging import uuid from abc import ABC, abstractmethod -from logging import Logger from typing import TYPE_CHECKING, Any, Optional from attrs import Attribute, Factory, define, field -from rich.logging import RichHandler from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable @@ -35,14 +32,10 @@ @define class Structure(ABC): - LOGGER_NAME = "griptape" - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) rulesets: list[Ruleset] = field(factory=list, kw_only=True) rules: list[Rule] = field(factory=list, kw_only=True) tasks: list[BaseTask] = field(factory=list, kw_only=True) - custom_logger: Optional[Logger] = field(default=None, kw_only=True) - logger_level: int = field(default=logging.INFO, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( default=Factory(lambda: ConversationMemory()), kw_only=True, @@ -55,7 +48,6 @@ class Structure(ABC): meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True) fail_fast: bool = field(default=True, kw_only=True) _execution_args: tuple = () - _logger: Optional[Logger] = None @rulesets.validator # pyright: ignore[reportAttributeAccessIssue] def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None: @@ -88,20 +80,6 @@ def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]: def execution_args(self) -> tuple: return self._execution_args - @property - def logger(self) -> Logger: - if self.custom_logger: - return self.custom_logger - else: - if self._logger is None: - self._logger = logging.getLogger(self.LOGGER_NAME) - - self._logger.propagate = False - self._logger.level = self.logger_level - - self._logger.handlers = [RichHandler(show_time=True, show_path=False)] - return self._logger - @property def input_task(self) -> Optional[BaseTask]: return self.tasks[0] if self.tasks else None diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 07f49f52a..e3c2aeb12 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import logging import re from typing import TYPE_CHECKING, Callable, Optional @@ -18,6 +19,8 @@ if TYPE_CHECKING: from griptape.memory import TaskMemory +logger = logging.getLogger(Config.logging.logger_name) + @define class ActionsSubtask(BaseTask): @@ -86,7 +89,7 @@ def attach_to(self, parent_task: BaseTask) -> None: else: self.__init_from_artifacts(self.input) except Exception as e: - self.structure.logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) + logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) @@ -109,7 +112,7 @@ def before_run(self) -> None: *([f"\nThought: {self.thought}"] if self.thought is not None else []), f"\nActions: {self.actions_to_json()}", ] - self.structure.logger.info("".join(parts)) + logger.info("".join(parts)) def run(self) -> BaseArtifact: try: @@ -128,7 +131,7 @@ def run(self) -> BaseArtifact: actions_output.append(output) self.output = ListArtifact(actions_output) except Exception as e: - self.structure.logger.exception("Subtask %s\n%s", self.id, e) + logger.exception("Subtask %s\n%s", self.id, e) self.output = ErrorArtifact(str(e), exception=e) if self.output is not None: @@ -169,7 +172,7 @@ def after_run(self) -> None: subtask_actions=self.actions_to_dicts(), ), ) - self.structure.logger.info("Subtask %s\nResponse: %s", self.id, response) + logger.info("Subtask %s\nResponse: %s", self.id, response) def actions_to_dicts(self) -> list[dict]: json_list = [] @@ -257,7 +260,7 @@ def __parse_actions(self, actions_matches: list[str]) -> None: self.actions = [self.__process_action_object(action_object) for action_object in actions_list] except json.JSONDecodeError as e: - self.structure.logger.exception("Subtask %s\nInvalid actions JSON: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nInvalid actions JSON: %s", self.origin_task.id, e) self.output = ErrorArtifact(f"Actions JSON decoding error: {e}", exception=e) @@ -314,10 +317,10 @@ def __validate_action(self, action: ToolAction) -> None: if activity_schema: activity_schema.validate(action.input) except schema.SchemaError as e: - self.structure.logger.exception("Subtask %s\nInvalid action JSON: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nInvalid action JSON: %s", self.origin_task.id, e) action.output = ErrorArtifact(f"Activity input JSON validation error: {e}", exception=e) except SyntaxError as e: - self.structure.logger.exception("Subtask %s\nSyntax error: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nSyntax error: %s", self.origin_task.id, e) action.output = ErrorArtifact(f"Syntax error: {e}", exception=e) diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index d2657561d..4d9d82362 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -1,21 +1,25 @@ from __future__ import annotations +import logging from abc import ABC from attrs import define +from griptape.config import Config from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.tasks import BaseTask +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseAudioGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py index 517c03a15..febd3f508 100644 --- a/griptape/tasks/base_audio_input_task.py +++ b/griptape/tasks/base_audio_input_task.py @@ -1,14 +1,18 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact +from griptape.config.config import Config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseAudioInputTask(RuleMixin, BaseTask, ABC): @@ -30,9 +34,9 @@ def input(self, value: AudioArtifact | Callable[[BaseTask], AudioArtifact]) -> N def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index d32e8f142..afbc2c05e 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os from abc import ABC from pathlib import Path @@ -7,6 +8,7 @@ from attrs import Attribute, define, field +from griptape.config import Config from griptape.loaders import ImageLoader from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.rules import Rule, Ruleset @@ -16,6 +18,9 @@ from griptape.artifacts import MediaArtifact +logger = logging.getLogger(Config.logging.logger_name) + + @define class BaseImageGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): """Provides a base class for image generation-related tasks. @@ -60,5 +65,5 @@ def all_negative_rulesets(self) -> list[Ruleset]: return task_rulesets def _read_from_file(self, path: str) -> MediaArtifact: - self.structure.logger.info("Reading image from %s", os.path.abspath(path)) + logger.info("Reading image from %s", os.path.abspath(path)) return ImageLoader().load(Path(path).read_bytes()) diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index a0d8cb9ac..6962098ca 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -1,15 +1,19 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import Factory, define, field from griptape.artifacts import ListArtifact, TextArtifact +from griptape.config import Config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseMultiTextInputTask(RuleMixin, BaseTask, ABC): @@ -48,9 +52,9 @@ def before_run(self) -> None: super().before_run() joined_input = "\n".join([i.to_text() for i in self.input]) - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, joined_input) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, joined_input) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 9a8361e6c..cdaf1b032 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import uuid from abc import ABC, abstractmethod from concurrent import futures @@ -16,6 +17,8 @@ from griptape.memory.meta import BaseMetaEntry from griptape.structures import Structure +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseTask(ABC): @@ -159,7 +162,7 @@ def execute(self) -> Optional[BaseArtifact]: self.after_run() except Exception as e: - self.structure.logger.exception("%s %s\n%s", self.__class__.__name__, self.id, e) + logger.exception("%s %s\n%s", self.__class__.__name__, self.id, e) self.output = ErrorArtifact(str(e), exception=e) finally: diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 90f60efcd..16f8c705c 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -1,15 +1,19 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import define, field from griptape.artifacts import TextArtifact +from griptape.config import Config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 +logger = logging.getLogger(Config.logging.logger_name) + @define class BaseTextInputTask(RuleMixin, BaseTask, ABC): @@ -36,9 +40,9 @@ def input(self, value: str | TextArtifact | Callable[[BaseTask], TextArtifact]) def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 19580b642..3769f26dc 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Callable, Optional from attrs import Factory, define, field @@ -14,6 +15,8 @@ if TYPE_CHECKING: from griptape.drivers import BasePromptDriver +logger = logging.getLogger(Config.logging.logger_name) + @define class PromptTask(RuleMixin, BaseTask): @@ -65,12 +68,12 @@ def default_system_template_generator(self, _: PromptTask) -> str: def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) def run(self) -> BaseArtifact: message = self.prompt_driver.run(self.prompt_stack) From 4220e0f1392f33463f1f43c82b2ea9f999f718e8 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 16:56:48 -0700 Subject: [PATCH 16/40] Fix bad rebase --- griptape/config/base_config.py | 9 +++++---- griptape/config/config.py | 2 -- griptape/drivers/prompt/base_prompt_driver.py | 2 -- griptape/tasks/actions_subtask.py | 1 + griptape/tasks/base_task.py | 1 + griptape/utils/stream.py | 2 -- tests/unit/drivers/prompt/test_base_prompt_driver.py | 3 +-- tests/unit/tasks/test_base_task.py | 2 +- 8 files changed, 9 insertions(+), 13 deletions(-) diff --git a/griptape/config/base_config.py b/griptape/config/base_config.py index a3f132ea9..9209aa4a4 100644 --- a/griptape/config/base_config.py +++ b/griptape/config/base_config.py @@ -2,12 +2,13 @@ from attrs import define -from griptape.config.base_driver_config import BaseDriverConfig -from griptape.config.events_config import EventsConfig from griptape.mixins.serializable_mixin import SerializableMixin +from .base_driver_config import BaseDriverConfig +from .logging_config import LoggingConfig -@define + +@define(kw_only=True) class BaseConfig(SerializableMixin, ABC): drivers: BaseDriverConfig - events: EventsConfig + logging: LoggingConfig diff --git a/griptape/config/config.py b/griptape/config/config.py index 8d29b5a0f..d81a8974b 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -2,7 +2,6 @@ from .base_config import BaseConfig from .base_driver_config import BaseDriverConfig -from .events_config import EventsConfig from .logging_config import LoggingConfig from .openai_driver_config import OpenAiDriverConfig @@ -10,7 +9,6 @@ @define class _Config(BaseConfig): drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) - events: EventsConfig = field(default=Factory(lambda: EventsConfig()), kw_only=True) logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index b6c28560b..94e46e75d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -113,8 +113,6 @@ def __process_run(self, prompt_stack: PromptStack) -> Message: return result def __process_stream(self, prompt_stack: PromptStack) -> Message: - from griptape.config import Config - delta_contents: dict[int, list[BaseDeltaMessageContent]] = {} usage = DeltaMessage.Usage() diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index e3c2aeb12..2f199e368 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -11,6 +11,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction +from griptape.config import Config from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index cdaf1b032..c42f73629 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -10,6 +10,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact +from griptape.config import Config from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent if TYPE_CHECKING: diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index cb5266378..87cb9dec8 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -56,8 +56,6 @@ def run(self, *args) -> Iterator[TextArtifact]: t.join() def _run_structure(self, *args) -> None: - from griptape.config import Config - def event_handler(event: BaseEvent) -> None: self._event_queue.put(event) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 248c259e5..c30acdec4 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -28,8 +28,7 @@ def test_run_via_pipeline_retries_failure(self, mock_config): def test_run_via_pipeline_publishes_events(self, mocker): mock_publish_event = mocker.patch.object(_EventBus, "publish_event") - driver = MockPromptDriver() - pipeline = Pipeline(prompt_driver=driver) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) pipeline.run() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d22ef35f7..d4e0ce23d 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -18,7 +18,7 @@ def task(self): agent = Agent( tools=[MockTool()], ) - Config.event_listeners = [EventListener(handler=Mock())] + EventBus.event_listeners = [EventListener(handler=Mock())] agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) From f8a616fcc93d35c4e9c10d5e8131bd64d57fb68c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 09:47:58 -0700 Subject: [PATCH 17/40] Update docs --- docs/examples/talk-to-a-video.md | 5 +- .../drivers/embedding-drivers.md | 11 +- .../drivers/event-listener-drivers.md | 13 +- docs/griptape-framework/structures/config.md | 123 ++++++++---------- .../structures/task-memory.md | 10 +- .../official-tools/rest-api-client.md | 15 ++- 6 files changed, 83 insertions(+), 94 deletions(-) diff --git a/docs/examples/talk-to-a-video.md b/docs/examples/talk-to-a-video.md index 310b6d407..cf41dea0f 100644 --- a/docs/examples/talk-to-a-video.md +++ b/docs/examples/talk-to-a-video.md @@ -7,9 +7,11 @@ import time from griptape.structures import Agent from griptape.tasks import PromptTask from griptape.artifacts import GenericArtifact, TextArtifact -from griptape.config import GoogleDriverConfig +from griptape.config import Config import google.generativeai as genai +Config.drivers = GoogleDriverConfig() + video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") while video_file.state.name == "PROCESSING": time.sleep(2) @@ -19,7 +21,6 @@ if video_file.state.name == "FAILED": raise ValueError(video_file.state.name) agent = Agent( - config=GoogleDriverConfig(), input=[ GenericArtifact(video_file), TextArtifact("Answer this question regarding the video: {{ args[0] }}"), diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 3c81cf8a9..7a8fd96a1 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -220,14 +220,15 @@ from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, ) -from griptape.config import DriverConfig +from griptape.config import DriverConfig, Config -agent = Agent( - tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], - config=DriverConfig( +Config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4o"), embedding=VoyageAiEmbeddingDriver(), - ), +) + +agent = Agent( + tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], ) agent.run("based on https://www.griptape.ai/, tell me what Griptape is") diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 8d4f521aa..20ae045f4 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -123,7 +123,7 @@ The [AwsIotCoreEventListenerDriver](../../reference/griptape/drivers/event_liste ```python import os -from griptape.config import DriverConfig +from griptape.config import DriverConfig, Config from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver from griptape.events import ( EventListener, @@ -132,17 +132,18 @@ from griptape.events import ( from griptape.rules import Rule from griptape.structures import Agent +Config.drivers = DriverConfig( + prompt=OpenAiChatPromptDriver( + model="gpt-3.5-turbo", temperature=0.7 + ) +) + agent = Agent( rules=[ Rule( value="You will be provided with a text, and your task is to extract the airport codes from it." ) ], - config=DriverConfig( - prompt=OpenAiChatPromptDriver( - model="gpt-3.5-turbo", temperature=0.7 - ) - ), event_listeners=[ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 13b6c001a..b4c928ff7 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -17,13 +17,11 @@ The [OpenAI Driver Config](../../reference/griptape/config/openai_driver_config. ```python from griptape.structures import Agent -from griptape.config import OpenAiDriverConfig +from griptape.config import OpenAiDriverConfig, Config -agent = Agent( - config=OpenAiDriverConfig() -) +Config.drivers = OpenAiDriverConfig() -agent = Agent() # This is equivalent to the above +agent = Agent() ``` #### Azure OpenAI @@ -33,18 +31,14 @@ The [Azure OpenAI Driver Config](../../reference/griptape/config/azure_openai_dr ```python import os from griptape.structures import Agent -from griptape.config import AzureOpenAiDriverConfig - -agent = Agent( - config=AzureOpenAiDriverConfig( - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], - api_key=os.environ["AZURE_OPENAI_API_KEY_3"] - ).merge_config({ - "image_query": { - "azure_deployment": "gpt-4o", - }, - }), +from griptape.config import AzureOpenAiDriverConfig, Config + +Config.drivers = AzureOpenAiDriverConfig( + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], + api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ) + +agent = Agent() ``` #### Amazon Bedrock @@ -54,17 +48,17 @@ The [Amazon Bedrock Driver Config](../../reference/griptape/config/amazon_bedroc import os import boto3 from griptape.structures import Agent -from griptape.config import AmazonBedrockDriverConfig - -agent = Agent( - config=AmazonBedrockDriverConfig( - session=boto3.Session( - region_name=os.environ["AWS_DEFAULT_REGION"], - aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], - aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], - ) +from griptape.config import AmazonBedrockDriverConfig, Config + +Config.drivers = AmazonBedrockDriverConfig( + session=boto3.Session( + region_name=os.environ["AWS_DEFAULT_REGION"], + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], ) ) + +agent = Agent() ``` #### Google @@ -72,11 +66,11 @@ The [Google Driver Config](../../reference/griptape/config/google_driver_config. ```python from griptape.structures import Agent -from griptape.config import GoogleDriverConfig +from griptape.config import GoogleDriverConfig, Config -agent = Agent( - config=GoogleDriverConfig() -) +Config.drivers = GoogleDriverConfig() + +agent = Agent() ``` #### Anthropic @@ -90,11 +84,11 @@ The [Anthropic Driver Config](../../reference/griptape/config/anthropic_driver_c ```python from griptape.structures import Agent -from griptape.config import AnthropicDriverConfig +from griptape.config import AnthropicDriverConfig, Config -agent = Agent( - config=AnthropicDriverConfig() -) +Config.drivers = AnthropicDriverConfig() + +agent = Agent() ``` #### Cohere @@ -103,10 +97,12 @@ The [Cohere Driver Config](../../reference/griptape/config/cohere_driver_config. ```python import os -from griptape.config import CohereDriverConfig +from griptape.config import CohereDriverConfig, Config from griptape.structures import Agent -agent = Agent(config=CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])) +Config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) + +agent = Agent() ``` ### Custom Configs @@ -118,49 +114,38 @@ This approach ensures that you are informed through clear error messages if you ```python import os from griptape.structures import Agent -from griptape.config import DriverConfig +from griptape.config import DriverConfig, Config from griptape.drivers import AnthropicPromptDriver -agent = Agent( - config=DriverConfig( - prompt=AnthropicPromptDriver( - model="claude-3-sonnet-20240229", - api_key=os.environ["ANTHROPIC_API_KEY"], - ) - ), +Config.drivers = DriverConfig( + prompt=AnthropicPromptDriver( + model="claude-3-sonnet-20240229", + api_key=os.environ["ANTHROPIC_API_KEY"], + ) ) + + +agent = Agent() ``` ### Loading/Saving Configs -Configuration classes in Griptape offer utility methods for loading, saving, and merging configurations, streamlining the management of complex setups. - ```python from griptape.structures import Agent -from griptape.config import AmazonBedrockDriverConfig -from griptape.drivers import AmazonBedrockCohereEmbeddingDriver +from griptape.config import AmazonBedrockDriverConfig, Config custom_config = AmazonBedrockDriverConfig() -custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver() -custom_config.merge_config( - { - "embedding": { - "base_url": None, - "model": "text-embedding-3-small", - "organization": None, - "type": "OpenAiEmbeddingDriver", - }, - } -) -serialized_config = custom_config.to_json() -deserialized_config = AmazonBedrockDriverConfig.from_json(serialized_config) - -agent = Agent( - config=deserialized_config.merge_config({ - "prompt": { - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - }, - }), -) +dict_config = custom_config.to_dict() +# Use OpenAi for embeddings +dict_config["embedding"] = { + "base_url": None, + "model": "text-embedding-3-small", + "organization": None, + "type": "OpenAiEmbeddingDriver", +} +custom_config = AmazonBedrockDriverConfig.from_dict(dict_config) + +Config.drivers = custom_config + +agent = Agent() ``` - diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 3184c4096..49d6b28cf 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -206,7 +206,7 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s ```python from griptape.artifacts import TextArtifact from griptape.config import ( - OpenAiDriverConfig, + Config, OpenAiDriverConfig, ) from griptape.drivers import ( LocalVectorStoreDriver, @@ -220,12 +220,13 @@ from griptape.memory.task.storage import TextArtifactStorage from griptape.structures import Agent from griptape.tools import FileManager, TaskMemoryClient, WebScraper +Config.drivers = OpenAiDriverConfig( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), +) + vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) agent = Agent( - config=OpenAiDriverConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), - ), task_memory=TaskMemory( artifact_storages={ TextArtifact: TextArtifactStorage( @@ -233,7 +234,6 @@ agent = Agent( retrieval_stage=RetrievalRagStage( retrieval_modules=[ VectorStoreRetrievalRagModule( - vector_store_driver=vector_store_driver, query_params={ "namespace": "griptape", diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index 675f77b6e..a73f6fa57 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -14,7 +14,14 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiClient -from griptape.config import DriverConfig +from griptape.config import Config + +Config.drivers = DriverConfig( + prompt=OpenAiChatPromptDriver( + model="gpt-4o", + temperature=0.1 + ), +) posts_client = RestApiClient( base_url="https://jsonplaceholder.typicode.com", @@ -117,12 +124,6 @@ posts_client = RestApiClient( pipeline = Pipeline( conversation_memory=ConversationMemory(), - config=DriverConfig( - prompt=OpenAiChatPromptDriver( - model="gpt-4o", - temperature=0.1 - ), - ), ) pipeline.add_tasks( From 95d83b5634409b95a76480480213aa23b6a5451f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 16:23:25 -0700 Subject: [PATCH 18/40] Add global event bus --- CHANGELOG.md | 3 + docs/griptape-framework/misc/events.md | 61 ++++++++++--------- griptape/config/base_structure_config.py | 40 ------------ .../base_audio_transcription_driver.py | 10 +-- .../embedding/base_embedding_driver.py | 4 +- .../base_image_generation_driver.py | 10 +-- .../image_query/base_image_query_driver.py | 10 +-- .../base_conversation_memory_driver.py | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 16 ++--- .../base_text_to_speech_driver.py | 9 +-- .../vector/base_vector_store_driver.py | 4 +- griptape/events/__init__.py | 2 + .../event_bus.py} | 5 +- griptape/mixins/__init__.py | 2 - griptape/structures/structure.py | 12 ++-- griptape/tasks/actions_subtask.py | 6 +- griptape/tasks/base_task.py | 6 +- griptape/utils/stream.py | 9 +-- tests/unit/config/test_structure_config.py | 35 ----------- tests/unit/conftest.py | 12 ++++ .../test_base_audio_transcription_driver.py | 4 +- .../test_base_image_generation_driver.py | 9 +-- .../test_base_image_query_driver.py | 4 +- .../drivers/prompt/test_base_prompt_driver.py | 7 +-- .../test_base_audio_transcription_driver.py | 4 +- tests/unit/events/test_event_bus.py | 45 ++++++++++++++ tests/unit/events/test_event_listener.py | 29 ++++----- tests/unit/mixins/test_events_mixin.py | 59 ------------------ tests/unit/tasks/test_base_task.py | 5 +- 29 files changed, 176 insertions(+), 250 deletions(-) rename griptape/{mixins/event_publisher_mixin.py => events/event_bus.py} (96%) create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/events/test_event_bus.py delete mode 100644 tests/unit/mixins/test_events_mixin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d8bf2e72..6748299d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. +- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. ### Changed +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 1f50fd6d0..187321dc6 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can use [EventListener](../../reference/griptape/events/event_listener.md)s to listen for events during a Structure's execution. +You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,15 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, + EventBus ) def handler(event: BaseEvent): print(event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener( handler, event_types=[ @@ -44,7 +43,8 @@ agent = Agent( ], ) ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -69,7 +69,8 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener +from griptape.events import BaseEvent, EventListener, EventBus + def handler1(event: BaseEvent): @@ -79,13 +80,12 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -131,7 +131,7 @@ Handler 2 list: - return [ - self.prompt_driver, - self.image_generation_driver, - self.image_query_driver, - self.embedding_driver, - self.vector_store_driver, - self.conversation_memory_driver, - self.text_to_speech_driver, - self.audio_transcription_driver, - ] - - @property - def structure(self) -> Optional[Structure]: - return self._structure - - @structure.setter - def structure(self, structure: Structure) -> None: - if structure != self.structure: - event_publisher_drivers = [ - driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin) - ] - - for driver in event_publisher_drivers: - if self._event_listener is not None: - driver.remove_event_listener(self._event_listener) - - self._event_listener = EventListener(structure.publish_event) - for driver in event_publisher_drivers: - driver.add_event_listener(self._event_listener) - - self._structure = structure - def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() merged_config = dict_merge(base_config, config) diff --git a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py index c81ea1d5b..ae46c474c 100644 --- a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import AudioArtifact, TextArtifact @define -class BaseAudioTranscriptionDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self) -> None: - self.publish_event(StartAudioTranscriptionEvent()) + EventBus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - self.publish_event(FinishAudioTranscriptionEvent()) + EventBus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index 690726060..8998f00e5 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -7,7 +7,7 @@ from attrs import define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import TextArtifact @@ -15,7 +15,7 @@ @define -class BaseEmbeddingDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base Embedding Driver. Attributes: diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index f500d6d09..8dfca5945 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @define -class BaseImageGenerationDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - self.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - self.publish_event(FinishImageGenerationEvent()) + EventBus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index b39f198d4..28c571328 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,24 +5,24 @@ from attrs import define, field -from griptape.events import FinishImageQueryEvent, StartImageQueryEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @define -class BaseImageQueryDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - self.publish_event( + EventBus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - self.publish_event(FinishImageQueryEvent(result=result)) + EventBus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py index f13b82c29..1caeb902f 100644 --- a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory -class BaseConversationMemoryDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseConversationMemoryDriver(SerializableMixin, ABC): @abstractmethod def store(self, memory: BaseConversationMemory) -> None: ... diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index e5fd0408d..94e46e75d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,8 +16,8 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from collections.abc import Iterator @@ -26,7 +26,7 @@ @define(kw_only=True) -class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base class for the Prompt Drivers. Attributes: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublishe use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - self.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - self.publish_event( + EventBus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - self.publish_event(CompletionChunkEvent(token=content.text)) + EventBus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - self.publish_event(CompletionChunkEvent(token=str(content))) + EventBus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - self.publish_event(CompletionChunkEvent(token=content.partial_input)) + EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index 788d92974..cb11cc498 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,23 +5,24 @@ from attrs import define, field +from griptape.events import EventBus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @define -class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - self.publish_event(StartTextToSpeechEvent(prompts=prompts)) + EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - self.publish_event(FinishTextToSpeechEvent()) + EventBus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index d1da78188..ed1f2d589 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -10,14 +10,14 @@ from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.drivers import BaseEmbeddingDriver @define -class BaseVectorStoreDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseVectorStoreDriver(SerializableMixin, ABC): DEFAULT_QUERY_COUNT = 5 @dataclass diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 944a309eb..b3e2f3a79 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,6 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent +from .event_bus import EventBus __all__ = [ "BaseEvent", @@ -48,4 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", + "EventBus", ] diff --git a/griptape/mixins/event_publisher_mixin.py b/griptape/events/event_bus.py similarity index 96% rename from griptape/mixins/event_publisher_mixin.py rename to griptape/events/event_bus.py index 67a302ed6..9239e66bd 100644 --- a/griptape/mixins/event_publisher_mixin.py +++ b/griptape/events/event_bus.py @@ -9,7 +9,7 @@ @define -class EventPublisherMixin: +class _EventBus: event_listeners: list[EventListener] = field(factory=list, kw_only=True) def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: @@ -32,3 +32,6 @@ def remove_event_listener(self, event_listener: EventListener) -> None: def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self.event_listeners: event_listener.publish_event(event, flush=flush) + + +EventBus = _EventBus() diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index 944027c59..d9eea53c2 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -4,7 +4,6 @@ from .rule_mixin import RuleMixin from .serializable_mixin import SerializableMixin from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin -from .event_publisher_mixin import EventPublisherMixin __all__ = [ "ActivityMixin", @@ -13,5 +12,4 @@ "RuleMixin", "BlobArtifactFileOutputMixin", "SerializableMixin", - "EventPublisherMixin", ] diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 079e0b741..df7113c23 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,13 +28,11 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events.finish_structure_run_event import FinishStructureRunEvent -from griptape.events.start_structure_run_event import StartStructureRunEvent +from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.mixins import EventPublisherMixin from griptape.utils import deprecation_warn if TYPE_CHECKING: @@ -44,7 +42,7 @@ @define -class Structure(ABC, EventPublisherMixin): +class Structure(ABC): LOGGER_NAME = "griptape" id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) @@ -97,8 +95,6 @@ def __attrs_post_init__(self) -> None: if self.conversation_memory is not None: self.conversation_memory.structure = self - self.config.structure = self - tasks = self.tasks.copy() self.tasks.clear() self.add_tasks(*tasks) @@ -261,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - self.publish_event( + EventBus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -273,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - self.publish_event( + EventBus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index cde59d0ef..07f49f52a 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - self.structure.publish_event( + EventBus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - self.structure.publish_event( + EventBus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 8c50e4df9..9a8361e6c 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import FinishTaskEvent, StartTaskEvent +from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index bf33e5df8..4a7899b2a 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,10 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events.completion_chunk_event import CompletionChunkEvent -from griptape.events.event_listener import EventListener -from griptape.events.finish_prompt_event import FinishPromptEvent -from griptape.events.finish_structure_run_event import FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent if TYPE_CHECKING: from collections.abc import Iterator @@ -64,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - self.structure.add_event_listener(stream_event_listener) + EventBus.add_event_listener(stream_event_listener) self.structure.run(*args) - self.structure.remove_event_listener(stream_event_listener) + EventBus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index b9e3477e4..96a68628f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,7 +1,6 @@ import pytest from griptape.config import StructureConfig -from griptape.structures import Agent class TestStructureConfig: @@ -61,37 +60,3 @@ def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 assert config.prompt_driver.max_tokens == 10 - - def test_drivers(self, config): - assert config.drivers == [ - config.prompt_driver, - config.image_generation_driver, - config.image_query_driver, - config.embedding_driver, - config.vector_store_driver, - config.conversation_memory_driver, - config.text_to_speech_driver, - config.audio_transcription_driver, - ] - - def test_structure(self, config): - structure_1 = Agent( - config=config, - ) - - assert config.structure == structure_1 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 - - structure_2 = Agent( - config=config, - ) - assert config.structure == structure_2 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..0be2f9758 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from griptape.events import EventBus + + +@pytest.fixture(autouse=True) +def event_bus(): + EventBus.event_listeners = [] + + yield EventBus + + EventBus.event_listeners = [] diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 519e40f57..fc41837fd 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 7447b2c08..96b615a58 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -14,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -30,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -52,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -80,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index 14de15f2d..a77fb268e 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 2708b0a88..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(EventPublisherMixin, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) @@ -42,8 +42,7 @@ def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), Message) def test_run_with_stream(self): - pipeline = Pipeline() - result = MockPromptDriver(stream=True, event_listeners=pipeline.event_listeners).run(PromptStack(messages=[])) + result = MockPromptDriver(stream=True).run(PromptStack(messages=[])) assert isinstance(result, Message) assert result.value == "mock output" diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index 8af5dc827..ab448c7c1 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py new file mode 100644 index 000000000..fd862913e --- /dev/null +++ b/tests/unit/events/test_event_bus.py @@ -0,0 +1,45 @@ +from unittest.mock import Mock + +from griptape.events import EventBus, EventListener +from tests.mocks.mock_event import MockEvent + + +class TestEventBus: + def test_add_event_listeners(self): + EventBus.add_event_listeners([EventListener(), EventListener()]) + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listeners(self): + listeners = [EventListener(), EventListener()] + EventBus.add_event_listeners(listeners) + EventBus.remove_event_listeners(listeners) + assert len(EventBus.event_listeners) == 0 + + def test_add_event_listener(self): + EventBus.add_event_listener(EventListener()) + EventBus.add_event_listener(EventListener()) + + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listener(self): + listener = EventListener() + EventBus.add_event_listener(listener) + EventBus.remove_event_listener(listener) + + assert len(EventBus.event_listeners) == 0 + + def test_remove_unknown_event_listener(self): + EventBus.remove_event_listener(EventListener()) + + def test_publish_event(self): + # Given + mock_handler = Mock() + mock_handler.return_value = None + EventBus.event_listeners = [EventListener(handler=mock_handler)] + mock_event = MockEvent() + + # When + EventBus.publish_event(mock_event) + + # Then + mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index b245c2be9..5601aef34 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,6 +4,7 @@ from griptape.events import ( CompletionChunkEvent, + EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -37,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - pipeline.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -58,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - pipeline.event_listeners = [ + EventBus.event_listeners = [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), EventListener(start_task_event_handler, event_types=[StartTaskEvent]), @@ -86,25 +87,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - pipeline.event_listeners = [] + EventBus.event_listeners = [] mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = pipeline.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = pipeline.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = pipeline.add_event_listener(EventListener(mock2)) + event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) - assert len(pipeline.event_listeners) == 4 + assert len(EventBus.event_listeners) == 4 - pipeline.remove_event_listener(event_listener_1) - pipeline.remove_event_listener(event_listener_3) - pipeline.remove_event_listener(event_listener_4) - pipeline.remove_event_listener(event_listener_5) - assert len(pipeline.event_listeners) == 0 + EventBus.remove_event_listener(event_listener_1) + EventBus.remove_event_listener(event_listener_3) + EventBus.remove_event_listener(event_listener_4) + EventBus.remove_event_listener(event_listener_5) + assert len(EventBus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/mixins/test_events_mixin.py b/tests/unit/mixins/test_events_mixin.py deleted file mode 100644 index 99f5541ba..000000000 --- a/tests/unit/mixins/test_events_mixin.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import Mock - -from griptape.events import EventListener -from griptape.mixins import EventPublisherMixin -from tests.mocks.mock_event import MockEvent - - -class TestEventsMixin: - def test_init(self): - assert EventPublisherMixin() - - def test_add_event_listeners(self): - mixin = EventPublisherMixin() - - mixin.add_event_listeners([EventListener(), EventListener()]) - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listeners(self): - mixin = EventPublisherMixin() - - listeners = [EventListener(), EventListener()] - mixin.add_event_listeners(listeners) - mixin.remove_event_listeners(listeners) - assert len(mixin.event_listeners) == 0 - - def test_add_event_listener(self): - mixin = EventPublisherMixin() - - mixin.add_event_listener(EventListener()) - mixin.add_event_listener(EventListener()) - - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listener(self): - mixin = EventPublisherMixin() - - listener = EventListener() - mixin.add_event_listener(listener) - mixin.remove_event_listener(listener) - - assert len(mixin.event_listeners) == 0 - - def test_remove_unknown_event_listener(self): - mixin = EventPublisherMixin() - - mixin.remove_event_listener(EventListener()) - - def test_publish_event(self): - # Given - mock_handler = Mock() - mock_handler.return_value = None - mixin = EventPublisherMixin(event_listeners=[EventListener(handler=mock_handler)]) - mock_event = MockEvent() - - # When - mixin.publish_event(mock_event) - - # Then - mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 4f4b43d40..636515106 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -15,11 +16,11 @@ class TestBaseTask: @pytest.fixture() def task(self): + EventBus.event_listeners = [EventListener(handler=Mock())] agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], - event_listeners=[EventListener(handler=Mock())], ) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -117,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert task.structure.event_listeners[0].handler.call_count == 2 + assert EventBus.event_listeners[0].handler.call_count == 2 From 951a4ed1fb163a47c5330d26de2dcc7c704b0e1b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 09:11:00 -0700 Subject: [PATCH 19/40] Update docs --- .../drivers/event-listener-drivers.md | 89 +++++++++++-------- docs/griptape-framework/misc/events.md | 14 +-- 2 files changed, 59 insertions(+), 44 deletions(-) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 73453afb6..db02cd77a 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,26 +14,27 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, + EventListener, EventBus ) from griptape.rules import Rule from griptape.structures import Agent -agent = Agent( - rules=[ - Rule( - value="You will be provided with a block of text, and your task is to extract a list of keywords from it." - ) - ], - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( - handler=lambda event: { # You can optionally use the handler to transform the event payload before sending it to the Driver - "event": event.to_dict(), - }, driver=AmazonSqsEventListenerDriver( queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], ), ), + ] +) + + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) ], ) @@ -83,23 +84,26 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, + EventListener, EventBus ) from griptape.rules import Rule from griptape.structures import Agent -agent = Agent( - rules=[ - Rule( - value="You will be provided with a block of text, and your task is to extract a list of keywords from it." - ) - ], - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( driver=AmazonSqsEventListenerDriver( queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], ), ), + ] +) + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) ], ) @@ -128,10 +132,23 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.rules import Rule from griptape.structures import Agent +EventBus.add_event_listeners( + [ + EventListener( + event_types=[FinishStructureRunEvent], + driver=AwsIotCoreEventListenerDriver( + topic=os.environ["AWS_IOT_CORE_TOPIC"], + iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], + ), + ), + ] +) + agent = Agent( rules=[ Rule( @@ -143,15 +160,6 @@ agent = Agent( model="gpt-3.5-turbo", temperature=0.7 ) ), - event_listeners=[ - EventListener( - event_types=[FinishStructureRunEvent], - driver=AwsIotCoreEventListenerDriver( - topic=os.environ["AWS_IOT_CORE_TOPIC"], - iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], - ), - ), - ], ) agent.run("I want to fly from Orlando to Boston") @@ -171,18 +179,19 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], # By default, GriptapeCloudEventListenerDriver uses the api key provided # in the GT_CLOUD_API_KEY environment variable. driver=GriptapeCloudEventListenerDriver(), ), - ], + ] ) agent.run( @@ -201,20 +210,23 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], driver=WebhookEventListenerDriver( webhook_url=os.environ["WEBHOOK_URL"], ), ), - ], + ] ) +agent = Agent() + agent.run("Analyze the pros and cons of remote work vs. office work") ``` ### Pusher @@ -229,12 +241,13 @@ import os from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, - FinishStructureRunEvent + FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], driver=PusherEventListenerDriver( @@ -250,6 +263,8 @@ agent = Agent( ], ) +agent = Agent() + agent.run("Analyze the pros and cons of remote work vs. office work") ``` diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 187321dc6..23ebcdc2a 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -30,7 +30,7 @@ from griptape.events import ( def handler(event: BaseEvent): print(event.__class__) -EventBus.event_listeners=[ +EventBus.add_event_listeners([ EventListener( handler, event_types=[ @@ -42,7 +42,7 @@ EventBus.event_listeners=[ FinishPromptEvent, ], ) - ] + ]) agent = Agent() @@ -140,12 +140,12 @@ from griptape.drivers import OpenAiChatPromptDriver -EventBus.event_listeners = [ +EventBus.add_event_listeners([ EventListener( lambda e: print(e.token, end="", flush=True), event_types=[CompletionChunkEvent], ) -] +]) pipeline = Pipeline( config=OpenAiStructureConfig( @@ -194,12 +194,12 @@ from griptape.structures import Agent token_counter = utils.TokenCounter() -EventBus.event_listeners = [ +EventBus.add_event_listeners([ EventListener( lambda e: token_counter.add_tokens(e.token_count), event_types=[StartPromptEvent, FinishPromptEvent], ) -] +]) def count_tokens(e: BaseEvent): if isinstance(e, StartPromptEvent) or isinstance(e, FinishPromptEvent): @@ -248,7 +248,7 @@ from griptape.structures import Agent from griptape.events import BaseEvent, StartPromptEvent, EventListener, EventBus -EventBus.event_listeners = [EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])] +EventBus.add_event_listeners([EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])]) def handler(event: BaseEvent): if isinstance(event, StartPromptEvent): From 025437b5f0f376318ce15fa5b9111a89f4608484 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:03:17 -0700 Subject: [PATCH 20/40] Make event listeners private --- griptape/events/event_bus.py | 19 +++++++++++----- tests/unit/conftest.py | 4 ++-- tests/unit/events/test_event_bus.py | 2 +- tests/unit/events/test_event_listener.py | 28 +++++++++++++----------- tests/unit/tasks/test_base_task.py | 2 +- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 9239e66bd..6ffd65550 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -10,7 +10,11 @@ @define class _EventBus: - event_listeners: list[EventListener] = field(factory=list, kw_only=True) + _event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners") + + @property + def event_listeners(self) -> list[EventListener]: + return self._event_listeners def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: return [self.add_event_listener(event_listener) for event_listener in event_listeners] @@ -20,18 +24,21 @@ def remove_event_listeners(self, event_listeners: list[EventListener]) -> None: self.remove_event_listener(event_listener) def add_event_listener(self, event_listener: EventListener) -> EventListener: - if event_listener not in self.event_listeners: - self.event_listeners.append(event_listener) + if event_listener not in self._event_listeners: + self._event_listeners.append(event_listener) return event_listener def remove_event_listener(self, event_listener: EventListener) -> None: - if event_listener in self.event_listeners: - self.event_listeners.remove(event_listener) + if event_listener in self._event_listeners: + self._event_listeners.remove(event_listener) def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: - for event_listener in self.event_listeners: + for event_listener in self._event_listeners: event_listener.publish_event(event, flush=flush) + def clear_event_listeners(self) -> None: + self._event_listeners.clear() + EventBus = _EventBus() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0be2f9758..7a73b041f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,8 +5,8 @@ @pytest.fixture(autouse=True) def event_bus(): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() yield EventBus - EventBus.event_listeners = [] + EventBus.clear_event_listeners() diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index fd862913e..d237bb3b4 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -35,7 +35,7 @@ def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.event_listeners = [EventListener(handler=mock_handler)] + EventBus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 5601aef34..f3d9823d3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,17 +59,19 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.event_listeners = [ - EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), - EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), - EventListener(start_task_event_handler, event_types=[StartTaskEvent]), - EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), - EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), - EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), - EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), - EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), - EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), - ] + EventBus.add_event_listeners( + [ + EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), + EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), + EventListener(start_task_event_handler, event_types=[StartTaskEvent]), + EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), + EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), + EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), + EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), + EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), + EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), + ] + ) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -87,7 +89,7 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 636515106..d6e4da8b6 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.event_listeners = [EventListener(handler=Mock())] + EventBus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), From 0f193854cc4b4230301b22578d5e14b42c1e72f0 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:17:34 -0700 Subject: [PATCH 21/40] Rename EventBus to event_bus --- CHANGELOG.md | 4 +-- .../drivers/event-listener-drivers.md | 24 +++++++-------- docs/griptape-framework/misc/events.md | 22 +++++++------- .../base_audio_transcription_driver.py | 6 ++-- .../base_image_generation_driver.py | 6 ++-- .../image_query/base_image_query_driver.py | 6 ++-- griptape/drivers/prompt/base_prompt_driver.py | 12 ++++---- .../base_text_to_speech_driver.py | 6 ++-- griptape/events/__init__.py | 4 +-- griptape/events/event_bus.py | 2 +- griptape/structures/structure.py | 6 ++-- griptape/tasks/actions_subtask.py | 6 ++-- griptape/tasks/base_task.py | 6 ++-- griptape/utils/stream.py | 6 ++-- tests/unit/conftest.py | 10 +++---- .../test_base_audio_transcription_driver.py | 4 +-- .../test_base_image_generation_driver.py | 10 +++---- .../test_base_image_query_driver.py | 4 +-- .../drivers/prompt/test_base_prompt_driver.py | 4 +-- .../test_base_audio_transcription_driver.py | 4 +-- tests/unit/events/test_event_bus.py | 30 +++++++++---------- tests/unit/events/test_event_listener.py | 30 +++++++++---------- tests/unit/tasks/test_base_task.py | 6 ++-- 23 files changed, 109 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6748299d0..7a95701c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. -- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. +- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. ### Changed -- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index db02cd77a..c3c92cfe1 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,12 +14,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -84,12 +84,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -132,12 +132,12 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -179,11 +179,11 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -210,11 +210,11 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -242,11 +242,11 @@ from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 23ebcdc2a..b3f4a77fd 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. +You can configure the global [event_bus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,14 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, - EventBus + event_bus ) def handler(event: BaseEvent): print(event.__class__) -EventBus.add_event_listeners([ +event_bus.add_event_listeners([ EventListener( handler, event_types=[ @@ -69,7 +69,7 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener, EventBus +from griptape.events import BaseEvent, EventListener, event_bus @@ -80,7 +80,7 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] @@ -131,7 +131,7 @@ Handler 2 None: - EventBus.publish_event(StartAudioTranscriptionEvent()) + event_bus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - EventBus.publish_event(FinishAudioTranscriptionEvent()) + event_bus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 8dfca5945..360fba8c9 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC) model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + event_bus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishImageGenerationEvent()) + event_bus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 28c571328..b1050b85c 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.events import FinishImageQueryEvent, StartImageQueryEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,12 +17,12 @@ class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - EventBus.publish_event( + event_bus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - EventBus.publish_event(FinishImageQueryEvent(result=result)) + event_bus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..8044469b5 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,7 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + event_bus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text)) + event_bus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content))) + event_bus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) + event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index cb11cc498..c74264dc1 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -19,10 +19,10 @@ class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + event_bus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishTextToSpeechEvent()) + event_bus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b3e2f3a79..431927663 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,7 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent -from .event_bus import EventBus +from .event_bus import event_bus __all__ = [ "BaseEvent", @@ -49,5 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", - "EventBus", + "event_bus", ] diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 6ffd65550..a956f7deb 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -41,4 +41,4 @@ def clear_event_listeners(self) -> None: self._event_listeners.clear() -EventBus = _EventBus() +event_bus = _EventBus() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index df7113c23..d68457ebc 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,7 +28,7 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent +from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory @@ -257,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - EventBus.publish_event( + event_bus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -269,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 07f49f52a..d600c80a5 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - EventBus.publish_event( + event_bus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 9a8361e6c..ade656f87 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent +from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 4a7899b2a..fd64a0f52 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventListener, FinishPromptEvent, FinishStructureRunEvent, event_bus if TYPE_CHECKING: from collections.abc import Iterator @@ -61,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - EventBus.add_event_listener(stream_event_listener) + event_bus.add_event_listener(stream_event_listener) self.structure.run(*args) - EventBus.remove_event_listener(stream_event_listener) + event_bus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7a73b041f..e462ede90 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,12 +1,12 @@ import pytest -from griptape.events import EventBus +from griptape.events import event_bus @pytest.fixture(autouse=True) -def event_bus(): - EventBus.clear_event_listeners() +def mock_event_bus(): + event_bus.clear_event_listeners() - yield EventBus + yield event_bus - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index fc41837fd..6fcab26e5 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 96b615a58..ab7b33ae8 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index a77fb268e..d8ba6b60f 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b6b0c600..52b7d5c0d 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _EventBus +from griptape.events.event_bus import _event_bus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_EventBus, "publish_event") + mock_publish_event = mocker.patch.object(_event_bus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index ab448c7c1..19493aa0f 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index d237bb3b4..7eb87036a 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -1,45 +1,45 @@ from unittest.mock import Mock -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_event import MockEvent class TestEventBus: def test_add_event_listeners(self): - EventBus.add_event_listeners([EventListener(), EventListener()]) - assert len(EventBus.event_listeners) == 2 + event_bus.add_event_listeners([EventListener(), EventListener()]) + assert len(event_bus.event_listeners) == 2 def test_remove_event_listeners(self): listeners = [EventListener(), EventListener()] - EventBus.add_event_listeners(listeners) - EventBus.remove_event_listeners(listeners) - assert len(EventBus.event_listeners) == 0 + event_bus.add_event_listeners(listeners) + event_bus.remove_event_listeners(listeners) + assert len(event_bus.event_listeners) == 0 def test_add_event_listener(self): - EventBus.add_event_listener(EventListener()) - EventBus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) - assert len(EventBus.event_listeners) == 2 + assert len(event_bus.event_listeners) == 2 def test_remove_event_listener(self): listener = EventListener() - EventBus.add_event_listener(listener) - EventBus.remove_event_listener(listener) + event_bus.add_event_listener(listener) + event_bus.remove_event_listener(listener) - assert len(EventBus.event_listeners) == 0 + assert len(event_bus.event_listeners) == 0 def test_remove_unknown_event_listener(self): - EventBus.remove_event_listener(EventListener()) + event_bus.remove_event_listener(EventListener()) def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.add_event_listeners([EventListener(handler=mock_handler)]) + event_bus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When - EventBus.publish_event(mock_event) + event_bus.publish_event(mock_event) # Then mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index f3d9823d3..50763e0c3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,7 +4,6 @@ from griptape.events import ( CompletionChunkEvent, - EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -14,6 +13,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + event_bus, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) + event_bus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.add_event_listeners( + event_bus.add_event_listeners( [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), @@ -89,25 +89,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = event_bus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = event_bus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) + event_listener_5 = event_bus.add_event_listener(EventListener(mock2)) - assert len(EventBus.event_listeners) == 4 + assert len(event_bus.event_listeners) == 4 - EventBus.remove_event_listener(event_listener_1) - EventBus.remove_event_listener(event_listener_3) - EventBus.remove_event_listener(event_listener_4) - EventBus.remove_event_listener(event_listener_5) - assert len(EventBus.event_listeners) == 0 + event_bus.remove_event_listener(event_listener_1) + event_bus.remove_event_listener(event_listener_3) + event_bus.remove_event_listener(event_listener_4) + event_bus.remove_event_listener(event_listener_5) + assert len(event_bus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d6e4da8b6..aa402bb48 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.add_event_listeners([EventListener(handler=Mock())]) + event_bus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), @@ -118,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert EventBus.event_listeners[0].handler.call_count == 2 + assert event_bus.event_listeners[0].handler.call_count == 2 From 729f3aad962e569321c0d263dc919b9972c6e2b2 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:29:50 -0700 Subject: [PATCH 22/40] Fix doc --- docs/griptape-framework/misc/events.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index b3f4a77fd..ebab3c460 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -80,10 +80,11 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -event_bus.event_listeners=[ +event_bus.add_event_listeners([ EventListener(handler1), EventListener(handler2), ] +) agent = Agent() From c391b1459042f4ad9f621ced9df052a3e0d41f83 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:05:52 -0700 Subject: [PATCH 23/40] Fix test --- tests/unit/drivers/prompt/test_base_prompt_driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 52b7d5c0d..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _event_bus +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_event_bus, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) From 7baefacc6853fc6bc7acb405f615f57b74f4f4eb Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:12:00 -0700 Subject: [PATCH 24/40] Rename event bus --- CHANGELOG.md | 4 +-- docs/griptape-framework/misc/events.md | 22 +++++++------- .../base_audio_transcription_driver.py | 6 ++-- .../base_image_generation_driver.py | 6 ++-- .../image_query/base_image_query_driver.py | 6 ++-- griptape/drivers/prompt/base_prompt_driver.py | 12 ++++---- .../base_text_to_speech_driver.py | 6 ++-- griptape/events/__init__.py | 4 +-- griptape/events/event_bus.py | 2 +- griptape/structures/structure.py | 6 ++-- griptape/tasks/actions_subtask.py | 6 ++-- griptape/tasks/base_task.py | 6 ++-- griptape/utils/stream.py | 6 ++-- tests/unit/conftest.py | 10 +++---- .../test_base_audio_transcription_driver.py | 4 +-- .../test_base_image_generation_driver.py | 10 +++---- .../test_base_image_query_driver.py | 4 +-- .../test_base_audio_transcription_driver.py | 4 +-- tests/unit/events/test_event_bus.py | 30 +++++++++---------- tests/unit/events/test_event_listener.py | 30 +++++++++---------- tests/unit/tasks/test_base_task.py | 8 ++--- 21 files changed, 96 insertions(+), 96 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea88983f3..f338ff961 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. -- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. +- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. ### Changed -- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index b7f118d98..d37a73663 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. +You can configure the global [event_bus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,14 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, - EventBus + event_bus ) def handler(event: BaseEvent): print(event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener( handler, event_types=[ @@ -69,7 +69,7 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener, EventBus +from griptape.events import BaseEvent, EventListener, event_bus @@ -80,7 +80,7 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] @@ -131,7 +131,7 @@ Handler 2 None: - EventBus.publish_event(StartAudioTranscriptionEvent()) + event_bus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - EventBus.publish_event(FinishAudioTranscriptionEvent()) + event_bus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 8dfca5945..360fba8c9 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC) model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + event_bus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishImageGenerationEvent()) + event_bus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 28c571328..b1050b85c 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.events import FinishImageQueryEvent, StartImageQueryEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,12 +17,12 @@ class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - EventBus.publish_event( + event_bus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - EventBus.publish_event(FinishImageQueryEvent(result=result)) + event_bus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..8044469b5 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,7 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + event_bus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text)) + event_bus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content))) + event_bus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) + event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index cb11cc498..c74264dc1 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -19,10 +19,10 @@ class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + event_bus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishTextToSpeechEvent()) + event_bus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b3e2f3a79..431927663 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,7 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent -from .event_bus import EventBus +from .event_bus import event_bus __all__ = [ "BaseEvent", @@ -49,5 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", - "EventBus", + "event_bus", ] diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 9239e66bd..c0881503d 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -34,4 +34,4 @@ def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: event_listener.publish_event(event, flush=flush) -EventBus = _EventBus() +event_bus = _EventBus() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 49197592f..6fea4d2e6 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -18,7 +18,7 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent +from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory @@ -180,7 +180,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - EventBus.publish_event( + event_bus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -192,7 +192,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 2f199e368..ccbf5dbb1 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -12,7 +12,7 @@ from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction from griptape.config import Config -from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -95,7 +95,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -161,7 +161,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - EventBus.publish_event( + event_bus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index c42f73629..2397fbfd0 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -11,7 +11,7 @@ from griptape.artifacts import ErrorArtifact from griptape.config import Config -from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent +from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -131,7 +131,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -143,7 +143,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 87cb9dec8..efca5c5b8 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventListener, FinishPromptEvent, FinishStructureRunEvent, event_bus if TYPE_CHECKING: from collections.abc import Iterator @@ -63,8 +63,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - EventBus.add_event_listener(stream_event_listener) + event_bus.add_event_listener(stream_event_listener) self.structure.run(*args) - EventBus.remove_event_listener(stream_event_listener) + event_bus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9207bbc1c..01af02573 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,17 +1,17 @@ import pytest from griptape.config import Config -from griptape.events import EventBus +from griptape.events import event_bus from tests.mocks.mock_driver_config import MockDriverConfig @pytest.fixture(autouse=True) -def event_bus(): - EventBus.event_listeners = [] +def mock_event_bus(): + event_bus.event_listeners = [] - yield EventBus + yield event_bus - EventBus.event_listeners = [] + event_bus.event_listeners = [] @pytest.fixture(autouse=True) diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 29aecfdf9..61ef3aa53 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver, mock_config): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 96b615a58..ab7b33ae8 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index a77fb268e..d8ba6b60f 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index ab448c7c1..19493aa0f 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index fd862913e..97aaa239b 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -1,45 +1,45 @@ from unittest.mock import Mock -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_event import MockEvent class TestEventBus: def test_add_event_listeners(self): - EventBus.add_event_listeners([EventListener(), EventListener()]) - assert len(EventBus.event_listeners) == 2 + event_bus.add_event_listeners([EventListener(), EventListener()]) + assert len(event_bus.event_listeners) == 2 def test_remove_event_listeners(self): listeners = [EventListener(), EventListener()] - EventBus.add_event_listeners(listeners) - EventBus.remove_event_listeners(listeners) - assert len(EventBus.event_listeners) == 0 + event_bus.add_event_listeners(listeners) + event_bus.remove_event_listeners(listeners) + assert len(event_bus.event_listeners) == 0 def test_add_event_listener(self): - EventBus.add_event_listener(EventListener()) - EventBus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) - assert len(EventBus.event_listeners) == 2 + assert len(event_bus.event_listeners) == 2 def test_remove_event_listener(self): listener = EventListener() - EventBus.add_event_listener(listener) - EventBus.remove_event_listener(listener) + event_bus.add_event_listener(listener) + event_bus.remove_event_listener(listener) - assert len(EventBus.event_listeners) == 0 + assert len(event_bus.event_listeners) == 0 def test_remove_unknown_event_listener(self): - EventBus.remove_event_listener(EventListener()) + event_bus.remove_event_listener(EventListener()) def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.event_listeners = [EventListener(handler=mock_handler)] + event_bus.event_listeners = [EventListener(handler=mock_handler)] mock_event = MockEvent() # When - EventBus.publish_event(mock_event) + event_bus.publish_event(mock_event) # Then mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 038cb4508..713e5ce42 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,7 +4,6 @@ from griptape.events import ( CompletionChunkEvent, - EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -14,6 +13,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + event_bus, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -39,7 +39,7 @@ def test_untyped_listeners(self, pipeline, mock_config): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + event_bus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -60,7 +60,7 @@ def test_typed_listeners(self, pipeline, mock_config): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.event_listeners = [ + event_bus.event_listeners = [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), EventListener(start_task_event_handler, event_types=[StartTaskEvent]), @@ -88,25 +88,25 @@ def test_typed_listeners(self, pipeline, mock_config): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.event_listeners = [] + event_bus.event_listeners = [] mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = event_bus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = event_bus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) + event_listener_5 = event_bus.add_event_listener(EventListener(mock2)) - assert len(EventBus.event_listeners) == 4 + assert len(event_bus.event_listeners) == 4 - EventBus.remove_event_listener(event_listener_1) - EventBus.remove_event_listener(event_listener_3) - EventBus.remove_event_listener(event_listener_4) - EventBus.remove_event_listener(event_listener_5) - assert len(EventBus.event_listeners) == 0 + event_bus.remove_event_listener(event_listener_1) + event_bus.remove_event_listener(event_listener_3) + event_bus.remove_event_listener(event_listener_4) + event_bus.remove_event_listener(event_listener_5) + assert len(event_bus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d4e0ce23d..4dfc890c9 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -14,11 +14,11 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.event_listeners = [EventListener(handler=Mock())] + event_bus.event_listeners = [EventListener(handler=Mock())] agent = Agent( tools=[MockTool()], ) - EventBus.event_listeners = [EventListener(handler=Mock())] + event_bus.event_listeners = [EventListener(handler=Mock())] agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -115,4 +115,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert EventBus.event_listeners[0].handler.call_count == 2 + assert event_bus.event_listeners[0].handler.call_count == 2 From feec94bc7bde4053db08bc19a0632cc89afc3393 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:27:23 -0700 Subject: [PATCH 25/40] Rename Config to config, fix tests --- docs/examples/talk-to-a-video.md | 4 +- .../drivers/embedding-drivers.md | 4 +- .../drivers/event-listener-drivers.md | 4 +- docs/griptape-framework/structures/config.md | 44 +++++++++---------- .../structures/task-memory.md | 4 +- .../official-tools/rest-api-client.md | 4 +- griptape/config/__init__.py | 4 +- griptape/config/config.py | 2 +- .../audio/audio_transcription_engine.py | 4 +- .../engines/audio/text_to_speech_engine.py | 4 +- .../extraction/base_extraction_engine.py | 4 +- .../image/base_image_generation_engine.py | 4 +- .../engines/image_query/image_query_engine.py | 4 +- .../response/prompt_response_rag_module.py | 4 +- .../vector_store_retrieval_rag_module.py | 4 +- .../engines/summary/prompt_summary_engine.py | 4 +- .../structure/base_conversation_memory.py | 6 +-- .../structure/summary_conversation_memory.py | 4 +- .../task/storage/text_artifact_storage.py | 4 +- griptape/structures/agent.py | 4 +- griptape/structures/structure.py | 10 ++--- griptape/tasks/actions_subtask.py | 4 +- griptape/tasks/base_audio_generation_task.py | 4 +- griptape/tasks/base_audio_input_task.py | 4 +- griptape/tasks/base_image_generation_task.py | 4 +- griptape/tasks/base_multi_text_input_task.py | 4 +- griptape/tasks/base_task.py | 4 +- griptape/tasks/base_text_input_task.py | 4 +- griptape/tasks/prompt_task.py | 6 +-- griptape/utils/chat.py | 8 ++-- griptape/utils/stream.py | 4 +- tests/mocks/docker/fake_api.py | 8 ++-- tests/unit/conftest.py | 10 ++--- tests/unit/tasks/test_base_task.py | 2 +- tests/unit/utils/test_stream.py | 6 +-- tests/utils/structure_tester.py | 4 +- 36 files changed, 103 insertions(+), 103 deletions(-) diff --git a/docs/examples/talk-to-a-video.md b/docs/examples/talk-to-a-video.md index cf41dea0f..20ad9952e 100644 --- a/docs/examples/talk-to-a-video.md +++ b/docs/examples/talk-to-a-video.md @@ -7,10 +7,10 @@ import time from griptape.structures import Agent from griptape.tasks import PromptTask from griptape.artifacts import GenericArtifact, TextArtifact -from griptape.config import Config +from griptape.config import config import google.generativeai as genai -Config.drivers = GoogleDriverConfig() +config.drivers = GoogleDriverConfig() video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") while video_file.state.name == "PROCESSING": diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 7a8fd96a1..f210f50b7 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -220,9 +220,9 @@ from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, ) -from griptape.config import DriverConfig, Config +from griptape.config import DriverConfig, config -Config.drivers = DriverConfig( +config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver(model="gpt-4o"), embedding=VoyageAiEmbeddingDriver(), ) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 4f1eeb391..e4e815709 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -127,7 +127,7 @@ The [AwsIotCoreEventListenerDriver](../../reference/griptape/drivers/event_liste ```python import os -from griptape.config import DriverConfig, Config +from griptape.config import DriverConfig, config from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver from griptape.events import ( EventListener, @@ -137,7 +137,7 @@ from griptape.events import ( from griptape.rules import Rule from griptape.structures import Agent -Config.drivers = DriverConfig( +config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver( model="gpt-3.5-turbo", temperature=0.7 ) diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index b4c928ff7..33fa27798 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -13,27 +13,27 @@ Griptape provides predefined [DriverConfig](../../reference/griptape/config/driv #### OpenAI -The [OpenAI Driver Config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. +The [OpenAI Driver config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. ```python from griptape.structures import Agent -from griptape.config import OpenAiDriverConfig, Config +from griptape.config import OpenAiDriverConfig, config -Config.drivers = OpenAiDriverConfig() +config.drivers = OpenAiDriverConfig() agent = Agent() ``` #### Azure OpenAI -The [Azure OpenAI Driver Config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. +The [Azure OpenAI Driver config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. ```python import os from griptape.structures import Agent -from griptape.config import AzureOpenAiDriverConfig, Config +from griptape.config import AzureOpenAiDriverConfig, config -Config.drivers = AzureOpenAiDriverConfig( +config.drivers = AzureOpenAiDriverConfig( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ) @@ -42,15 +42,15 @@ agent = Agent() ``` #### Amazon Bedrock -The [Amazon Bedrock Driver Config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. +The [Amazon Bedrock Driver config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. ```python import os import boto3 from griptape.structures import Agent -from griptape.config import AmazonBedrockDriverConfig, Config +from griptape.config import AmazonBedrockDriverConfig, config -Config.drivers = AmazonBedrockDriverConfig( +config.drivers = AmazonBedrockDriverConfig( session=boto3.Session( region_name=os.environ["AWS_DEFAULT_REGION"], aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], @@ -62,20 +62,20 @@ agent = Agent() ``` #### Google -The [Google Driver Config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. +The [Google Driver config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. ```python from griptape.structures import Agent -from griptape.config import GoogleDriverConfig, Config +from griptape.config import GoogleDriverConfig, config -Config.drivers = GoogleDriverConfig() +config.drivers = GoogleDriverConfig() agent = Agent() ``` #### Anthropic -The [Anthropic Driver Config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. +The [Anthropic Driver config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. !!! info Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. @@ -84,23 +84,23 @@ The [Anthropic Driver Config](../../reference/griptape/config/anthropic_driver_c ```python from griptape.structures import Agent -from griptape.config import AnthropicDriverConfig, Config +from griptape.config import AnthropicDriverConfig, config -Config.drivers = AnthropicDriverConfig() +config.drivers = AnthropicDriverConfig() agent = Agent() ``` #### Cohere -The [Cohere Driver Config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. +The [Cohere Driver config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. ```python import os -from griptape.config import CohereDriverConfig, Config +from griptape.config import CohereDriverConfig, config from griptape.structures import Agent -Config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) +config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) agent = Agent() ``` @@ -114,10 +114,10 @@ This approach ensures that you are informed through clear error messages if you ```python import os from griptape.structures import Agent -from griptape.config import DriverConfig, Config +from griptape.config import DriverConfig, config from griptape.drivers import AnthropicPromptDriver -Config.drivers = DriverConfig( +config.drivers = DriverConfig( prompt=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], @@ -132,7 +132,7 @@ agent = Agent() ```python from griptape.structures import Agent -from griptape.config import AmazonBedrockDriverConfig, Config +from griptape.config import AmazonBedrockDriverConfig, config custom_config = AmazonBedrockDriverConfig() dict_config = custom_config.to_dict() @@ -145,7 +145,7 @@ dict_config["embedding"] = { } custom_config = AmazonBedrockDriverConfig.from_dict(dict_config) -Config.drivers = custom_config +config.drivers = custom_config agent = Agent() ``` diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index 49d6b28cf..1fad33856 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -206,7 +206,7 @@ In this example, GPT-4 _never_ sees the contents of the page, only that it was s ```python from griptape.artifacts import TextArtifact from griptape.config import ( - Config, OpenAiDriverConfig, + config, OpenAiDriverConfig, ) from griptape.drivers import ( LocalVectorStoreDriver, @@ -220,7 +220,7 @@ from griptape.memory.task.storage import TextArtifactStorage from griptape.structures import Agent from griptape.tools import FileManager, TaskMemoryClient, WebScraper -Config.drivers = OpenAiDriverConfig( +config.drivers = OpenAiDriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ) diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index a73f6fa57..0151c2efd 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -14,9 +14,9 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiClient -from griptape.config import Config +from griptape.config import config -Config.drivers = DriverConfig( +config.drivers = DriverConfig( prompt=OpenAiChatPromptDriver( model="gpt-4o", temperature=0.1 diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 7450d7738..b242d80a7 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -9,7 +9,7 @@ from .anthropic_driver_config import AnthropicDriverConfig from .google_driver_config import GoogleDriverConfig from .cohere_driver_config import CohereDriverConfig -from .config import Config +from .config import config __all__ = [ @@ -22,5 +22,5 @@ "AnthropicDriverConfig", "GoogleDriverConfig", "CohereDriverConfig", - "Config", + "config", ] diff --git a/griptape/config/config.py b/griptape/config/config.py index d81a8974b..97d501abb 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -12,4 +12,4 @@ class _Config(BaseConfig): logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True) -Config = _Config() +config = _Config() diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index 51022e47c..cad8287d5 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -1,14 +1,14 @@ from attrs import Factory, define, field from griptape.artifacts import AudioArtifact, TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.drivers import BaseAudioTranscriptionDriver @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: Config.drivers.audio_transcription), kw_only=True + default=Factory(lambda: config.drivers.audio_transcription), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index a163c36fd..aad45a10a 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import Config +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: Config.drivers.text_to_speech), kw_only=True + default=Factory(lambda: config.drivers.text_to_speech), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index a1bcbdee2..4b1184e5e 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Attribute, Factory, define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.config import Config +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ErrorArtifact, ListArtifact @@ -18,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 921d600c7..9bec68b91 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field -from griptape.config import Config +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: Config.drivers.image_generation) + kw_only=True, default=Factory(lambda: config.drivers.image_generation) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index d85e6012d..f2bd99544 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import Config +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @@ -13,7 +13,7 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: Config.drivers.image_query), kw_only=True) + image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: config.drivers.image_query), kw_only=True) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 8e421d792..9804404fc 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.engines.rag.modules import BaseResponseRagModule from griptape.utils import J2 @@ -17,7 +17,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): answer_token_offset: int = field(default=400) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 4daa10e54..6ce235fa5 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from griptape import utils -from griptape.config import Config +from griptape.config import config from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -18,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 1c45fa5ea..82c33a0ad 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -7,7 +7,7 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.common import Message, PromptStack -from griptape.config import Config +from griptape.config import config from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index e7c8ed488..d6e3549af 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.common import PromptStack -from griptape.config import Config +from griptape.config import config from griptape.mixins import SerializableMixin if TYPE_CHECKING: @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: Config.drivers.conversation_memory), kw_only=True + default=Factory(lambda: config.drivers.conversation_memory), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = Config.drivers.prompt + prompt_driver = config.drivers.prompt temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 50be69a61..4263e61c8 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.common import Message, PromptStack -from griptape.config import Config +from griptape.config import config from griptape.memory.structure import ConversationMemory from griptape.utils import J2 @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: Config.drivers.prompt)) + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: config.drivers.prompt)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 460581997..ded114213 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -5,7 +5,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.engines.rag import RagContext, RagEngine from griptape.memory.task.storage import BaseArtifactStorage @@ -16,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Config.drivers.vector_store)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index f31e9d2eb..59a865897 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -6,7 +6,7 @@ from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable -from griptape.config import Config +from griptape.config import config from griptape.memory.structure import Run from griptape.structures import Structure from griptape.tasks import PromptTask, ToolkitTask @@ -24,7 +24,7 @@ class Agent(Structure): default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) stream: bool = field(default=False, kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 6fea4d2e6..b7ca84c4f 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -8,7 +8,7 @@ from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable -from griptape.config import Config +from griptape.config import config from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( @@ -118,10 +118,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=Config.drivers.vector_store, - summary_engine=PromptSummaryEngine(prompt_driver=Config.drivers.prompt), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=Config.drivers.prompt), - json_extraction_engine=JsonExtractionEngine(prompt_driver=Config.drivers.prompt), + vector_store_driver=config.drivers.vector_store, + summary_engine=PromptSummaryEngine(prompt_driver=config.drivers.prompt), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=config.drivers.prompt), + json_extraction_engine=JsonExtractionEngine(prompt_driver=config.drivers.prompt), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index ccbf5dbb1..0f885d260 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -11,7 +11,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.config import Config +from griptape.config import config from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask @@ -20,7 +20,7 @@ if TYPE_CHECKING: from griptape.memory import TaskMemory -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index 4d9d82362..00774e0a2 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -5,11 +5,11 @@ from attrs import define -from griptape.config import Config +from griptape.config import config from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.tasks import BaseTask -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py index febd3f508..8a470bb85 100644 --- a/griptape/tasks/base_audio_input_task.py +++ b/griptape/tasks/base_audio_input_task.py @@ -7,11 +7,11 @@ from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact -from griptape.config.config import Config +from griptape.config.config import config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index afbc2c05e..f94ff8de2 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -8,7 +8,7 @@ from attrs import Attribute, define, field -from griptape.config import Config +from griptape.config import config from griptape.loaders import ImageLoader from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.rules import Rule, Ruleset @@ -18,7 +18,7 @@ from griptape.artifacts import MediaArtifact -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index 6962098ca..c688a1129 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -7,12 +7,12 @@ from attrs import Factory, define, field from griptape.artifacts import ListArtifact, TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 2397fbfd0..b3086bebb 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -10,7 +10,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.config import Config +from griptape.config import config from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: @@ -18,7 +18,7 @@ from griptape.memory.meta import BaseMetaEntry from griptape.structures import Structure -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 16f8c705c..1c9dfc023 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -7,12 +7,12 @@ from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.config import Config +from griptape.config import config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 3769f26dc..a8038832d 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -7,7 +7,7 @@ from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack -from griptape.config import Config +from griptape.config import config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -15,12 +15,12 @@ if TYPE_CHECKING: from griptape.drivers import BasePromptDriver -logger = logging.getLogger(Config.logging.logger_name) +logger = logging.getLogger(config.logging.logger_name) @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: Config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 99b5a7dc3..56b53c0ce 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,15 +25,15 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - from griptape.config import Config + from griptape.config import config - if Config.drivers.prompt.stream: + if config.drivers.prompt.stream: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 def start(self) -> None: - from griptape.config import Config + from griptape.config import config if self.intro_text: self.output_fn(self.intro_text) @@ -44,7 +44,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if Config.drivers.prompt.stream: + if config.drivers.prompt.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index efca5c5b8..c5545bc44 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -34,9 +34,9 @@ class Stream: @structure.validator # pyright: ignore[reportAttributeAccessIssue] def validate_structure(self, _: Attribute, structure: Structure) -> None: - from griptape.config import Config + from griptape.config import config - if not Config.drivers.prompt.stream: + if not config.drivers.prompt.stream: raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/mocks/docker/fake_api.py b/tests/mocks/docker/fake_api.py index 881093057..00e750232 100644 --- a/tests/mocks/docker/fake_api.py +++ b/tests/mocks/docker/fake_api.py @@ -154,7 +154,7 @@ def get_fake_inspect_container(*, tty=False): status_code = 200 response = { "Id": FAKE_CONTAINER_ID, - "Config": {"Labels": {"foo": "bar"}, "Privileged": True, "Tty": tty}, + "config": {"Labels": {"foo": "bar"}, "Privileged": True, "Tty": tty}, "ID": FAKE_CONTAINER_ID, "Image": "busybox:latest", "Name": "foobar", @@ -166,7 +166,7 @@ def get_fake_inspect_container(*, tty=False): "StartedAt": "2013-09-25T14:01:18.869545111+02:00", "Ghost": False, }, - "HostConfig": {"LogConfig": {"Type": "json-file", "Config": {}}}, + "HostConfig": {"LogConfig": {"Type": "json-file", "config": {}}}, "MacAddress": "02:42:ac:11:00:0a", } return status_code, response @@ -179,7 +179,7 @@ def get_fake_inspect_image(): "Parent": "27cf784147099545", "Created": "2013-03-23T22:24:18.818426-07:00", "Container": FAKE_CONTAINER_ID, - "Config": {"Labels": {"bar": "foo"}}, + "config": {"Labels": {"bar": "foo"}}, "ContainerConfig": { "Hostname": "", "User": "", @@ -446,7 +446,7 @@ def get_fake_network_list(): "Driver": "bridge", "EnableIPv6": False, "Internal": False, - "IPAM": {"Driver": "default", "Config": [{"Subnet": "172.17.0.0/16"}]}, + "IPAM": {"Driver": "default", "config": [{"Subnet": "172.17.0.0/16"}]}, "Containers": { FAKE_CONTAINER_ID: { "EndpointID": "ed2419a97c1d99", diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 01af02573..8a37f6d28 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,21 +1,21 @@ import pytest -from griptape.config import Config +from griptape.config import config from griptape.events import event_bus from tests.mocks.mock_driver_config import MockDriverConfig @pytest.fixture(autouse=True) def mock_event_bus(): - event_bus.event_listeners = [] + event_bus.clear_event_listeners() yield event_bus - event_bus.event_listeners = [] + event_bus.clear_event_listeners() @pytest.fixture(autouse=True) def mock_config(): - Config.drivers = MockDriverConfig() + config.drivers = MockDriverConfig() - return Config + return config diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 90e826d19..1b45b4e98 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -18,7 +18,7 @@ def task(self): agent = Agent( tools=[MockTool()], ) - event_bus.event_listeners = [EventListener(handler=Mock())] + event_bus.add_event_listeners([EventListener(handler=Mock())]) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index 318f434c3..edd0258f2 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -2,7 +2,7 @@ import pytest -from griptape.config import Config +from griptape.config import config from griptape.structures import Agent from griptape.utils import Stream @@ -10,11 +10,11 @@ class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - Config.drivers.prompt.stream = request.param + config.drivers.prompt.stream = request.param return Agent() def test_init(self, agent): - if Config.drivers.prompt.stream: + if config.drivers.prompt.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 2b9f83b81..d87fc095e 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -226,9 +226,9 @@ def prompt_driver_id_fn(cls, prompt_driver) -> str: return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}" def verify_structure_output(self, structure) -> dict: - from griptape.config import Config + from griptape.config import config - Config.drivers.prompt = AzureOpenAiChatPromptDriver( + config.drivers.prompt = AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-4o", azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], From 13969c3a71c23ee733f73c5bf115d8d65bdd57e3 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:30:26 -0700 Subject: [PATCH 26/40] Fix doc --- docs/examples/multiple-agent-shared-memory.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index 0fe589d7b..30ff03ecc 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -11,7 +11,7 @@ import os from griptape.tools import WebScraper, TaskMemoryClient from griptape.structures import Agent from griptape.drivers import AzureOpenAiEmbeddingDriver, AzureMongoDbVectorStoreDriver -from griptape.config import AzureOpenAiDriverConfig +from griptape.config import AzureOpenAiDriverConfig, config AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"] AZURE_OPENAI_API_KEY_1 = os.environ["AZURE_OPENAI_API_KEY_1"] @@ -40,7 +40,7 @@ mongo_driver = AzureMongoDbVectorStoreDriver( vector_path=MONGODB_VECTOR_PATH, ) -config = AzureOpenAiDriverConfig( +config.drivers = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, vector_store=mongo_driver, embedding=embedding_driver, From f5c42e8944e65127243d94e1def66ceae43a3b3f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 12:05:50 -0700 Subject: [PATCH 27/40] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a95701c2..d4aaab1a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. +- Global config, `griptape.config.config`, for setting global configuration defaults. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. +- **BREAKING**: Removed `Workflow.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. +- **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.config.config.drivers.embedding` instead. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 From dd23d895a665a28572b91168887bb600f85e7c4c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 7 Aug 2024 16:23:25 -0700 Subject: [PATCH 28/40] Add global event bus --- CHANGELOG.md | 3 + docs/griptape-framework/misc/events.md | 61 ++++++++++--------- griptape/config/base_structure_config.py | 40 ------------ .../base_audio_transcription_driver.py | 10 +-- .../embedding/base_embedding_driver.py | 4 +- .../base_image_generation_driver.py | 10 +-- .../image_query/base_image_query_driver.py | 10 +-- .../base_conversation_memory_driver.py | 4 +- griptape/drivers/prompt/base_prompt_driver.py | 16 ++--- .../base_text_to_speech_driver.py | 9 +-- .../vector/base_vector_store_driver.py | 4 +- griptape/events/__init__.py | 2 + .../event_bus.py} | 5 +- griptape/mixins/__init__.py | 2 - griptape/structures/structure.py | 12 ++-- griptape/tasks/actions_subtask.py | 6 +- griptape/tasks/base_task.py | 6 +- griptape/utils/stream.py | 9 +-- tests/unit/config/test_structure_config.py | 35 ----------- tests/unit/conftest.py | 12 ++++ .../test_base_audio_transcription_driver.py | 4 +- .../test_base_image_generation_driver.py | 9 +-- .../test_base_image_query_driver.py | 4 +- .../drivers/prompt/test_base_prompt_driver.py | 7 +-- .../test_base_audio_transcription_driver.py | 4 +- tests/unit/events/test_event_bus.py | 45 ++++++++++++++ tests/unit/events/test_event_listener.py | 29 ++++----- tests/unit/mixins/test_events_mixin.py | 59 ------------------ tests/unit/tasks/test_base_task.py | 5 +- 29 files changed, 176 insertions(+), 250 deletions(-) rename griptape/{mixins/event_publisher_mixin.py => events/event_bus.py} (96%) create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/events/test_event_bus.py delete mode 100644 tests/unit/mixins/test_events_mixin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f9b2e72e8..76f705ddb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. +- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. ### Changed +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 1f50fd6d0..187321dc6 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can use [EventListener](../../reference/griptape/events/event_listener.md)s to listen for events during a Structure's execution. +You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,15 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, + EventBus ) def handler(event: BaseEvent): print(event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener( handler, event_types=[ @@ -44,7 +43,8 @@ agent = Agent( ], ) ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -69,7 +69,8 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener +from griptape.events import BaseEvent, EventListener, EventBus + def handler1(event: BaseEvent): @@ -79,13 +80,12 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) - -agent = Agent( - event_listeners=[ +EventBus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] -) + +agent = Agent() agent.run("tell me about griptape") ``` @@ -131,7 +131,7 @@ Handler 2 list: - return [ - self.prompt_driver, - self.image_generation_driver, - self.image_query_driver, - self.embedding_driver, - self.vector_store_driver, - self.conversation_memory_driver, - self.text_to_speech_driver, - self.audio_transcription_driver, - ] - - @property - def structure(self) -> Optional[Structure]: - return self._structure - - @structure.setter - def structure(self, structure: Structure) -> None: - if structure != self.structure: - event_publisher_drivers = [ - driver for driver in self.drivers if driver is not None and isinstance(driver, EventPublisherMixin) - ] - - for driver in event_publisher_drivers: - if self._event_listener is not None: - driver.remove_event_listener(self._event_listener) - - self._event_listener = EventListener(structure.publish_event) - for driver in event_publisher_drivers: - driver.add_event_listener(self._event_listener) - - self._structure = structure - def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() merged_config = dict_merge(base_config, config) diff --git a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py index c81ea1d5b..ae46c474c 100644 --- a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import AudioArtifact, TextArtifact @define -class BaseAudioTranscriptionDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self) -> None: - self.publish_event(StartAudioTranscriptionEvent()) + EventBus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - self.publish_event(FinishAudioTranscriptionEvent()) + EventBus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index 690726060..8998f00e5 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -7,7 +7,7 @@ from attrs import define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import TextArtifact @@ -15,7 +15,7 @@ @define -class BaseEmbeddingDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base Embedding Driver. Attributes: diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index f500d6d09..8dfca5945 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,22 +5,22 @@ from attrs import define, field -from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @define -class BaseImageGenerationDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - self.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - self.publish_event(FinishImageGenerationEvent()) + EventBus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index b39f198d4..28c571328 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,24 +5,24 @@ from attrs import define, field -from griptape.events import FinishImageQueryEvent, StartImageQueryEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @define -class BaseImageQueryDriver(EventPublisherMixin, SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - self.publish_event( + EventBus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - self.publish_event(FinishImageQueryEvent(result=result)) + EventBus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py index f13b82c29..1caeb902f 100644 --- a/griptape/drivers/memory/conversation/base_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/base_conversation_memory_driver.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory -class BaseConversationMemoryDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseConversationMemoryDriver(SerializableMixin, ABC): @abstractmethod def store(self, memory: BaseConversationMemory) -> None: ... diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index e5fd0408d..94e46e75d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,8 +16,8 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from collections.abc import Iterator @@ -26,7 +26,7 @@ @define(kw_only=True) -class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): """Base class for the Prompt Drivers. Attributes: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, EventPublishe use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - self.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - self.publish_event( + EventBus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - self.publish_event(CompletionChunkEvent(token=content.text)) + EventBus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - self.publish_event(CompletionChunkEvent(token=str(content))) + EventBus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - self.publish_event(CompletionChunkEvent(token=content.partial_input)) + EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index 788d92974..cb11cc498 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,23 +5,24 @@ from attrs import define, field +from griptape.events import EventBus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent -from griptape.mixins import EventPublisherMixin, ExponentialBackoffMixin, SerializableMixin +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @define -class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, EventPublisherMixin, ABC): +class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - self.publish_event(StartTextToSpeechEvent(prompts=prompts)) + EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - self.publish_event(FinishTextToSpeechEvent()) + EventBus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index d1da78188..ed1f2d589 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -10,14 +10,14 @@ from griptape import utils from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact -from griptape.mixins import EventPublisherMixin, SerializableMixin +from griptape.mixins import SerializableMixin if TYPE_CHECKING: from griptape.drivers import BaseEmbeddingDriver @define -class BaseVectorStoreDriver(EventPublisherMixin, SerializableMixin, ABC): +class BaseVectorStoreDriver(SerializableMixin, ABC): DEFAULT_QUERY_COUNT = 5 @dataclass diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 944a309eb..b3e2f3a79 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,6 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent +from .event_bus import EventBus __all__ = [ "BaseEvent", @@ -48,4 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", + "EventBus", ] diff --git a/griptape/mixins/event_publisher_mixin.py b/griptape/events/event_bus.py similarity index 96% rename from griptape/mixins/event_publisher_mixin.py rename to griptape/events/event_bus.py index 67a302ed6..9239e66bd 100644 --- a/griptape/mixins/event_publisher_mixin.py +++ b/griptape/events/event_bus.py @@ -9,7 +9,7 @@ @define -class EventPublisherMixin: +class _EventBus: event_listeners: list[EventListener] = field(factory=list, kw_only=True) def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: @@ -32,3 +32,6 @@ def remove_event_listener(self, event_listener: EventListener) -> None: def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: for event_listener in self.event_listeners: event_listener.publish_event(event, flush=flush) + + +EventBus = _EventBus() diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index 944027c59..d9eea53c2 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -4,7 +4,6 @@ from .rule_mixin import RuleMixin from .serializable_mixin import SerializableMixin from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin -from .event_publisher_mixin import EventPublisherMixin __all__ = [ "ActivityMixin", @@ -13,5 +12,4 @@ "RuleMixin", "BlobArtifactFileOutputMixin", "SerializableMixin", - "EventPublisherMixin", ] diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 079e0b741..df7113c23 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,13 +28,11 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events.finish_structure_run_event import FinishStructureRunEvent -from griptape.events.start_structure_run_event import StartStructureRunEvent +from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.mixins import EventPublisherMixin from griptape.utils import deprecation_warn if TYPE_CHECKING: @@ -44,7 +42,7 @@ @define -class Structure(ABC, EventPublisherMixin): +class Structure(ABC): LOGGER_NAME = "griptape" id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) @@ -97,8 +95,6 @@ def __attrs_post_init__(self) -> None: if self.conversation_memory is not None: self.conversation_memory.structure = self - self.config.structure = self - tasks = self.tasks.copy() self.tasks.clear() self.add_tasks(*tasks) @@ -261,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - self.publish_event( + EventBus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -273,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - self.publish_event( + EventBus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index cde59d0ef..07f49f52a 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - self.structure.publish_event( + EventBus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - self.structure.publish_event( + EventBus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 8c50e4df9..9a8361e6c 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import FinishTaskEvent, StartTaskEvent +from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - self.structure.publish_event( + EventBus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index bf33e5df8..4a7899b2a 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,10 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events.completion_chunk_event import CompletionChunkEvent -from griptape.events.event_listener import EventListener -from griptape.events.finish_prompt_event import FinishPromptEvent -from griptape.events.finish_structure_run_event import FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent if TYPE_CHECKING: from collections.abc import Iterator @@ -64,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - self.structure.add_event_listener(stream_event_listener) + EventBus.add_event_listener(stream_event_listener) self.structure.run(*args) - self.structure.remove_event_listener(stream_event_listener) + EventBus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index b9e3477e4..96a68628f 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -1,7 +1,6 @@ import pytest from griptape.config import StructureConfig -from griptape.structures import Agent class TestStructureConfig: @@ -61,37 +60,3 @@ def test_dot_update(self, config): config.prompt_driver.max_tokens = 10 assert config.prompt_driver.max_tokens == 10 - - def test_drivers(self, config): - assert config.drivers == [ - config.prompt_driver, - config.image_generation_driver, - config.image_query_driver, - config.embedding_driver, - config.vector_store_driver, - config.conversation_memory_driver, - config.text_to_speech_driver, - config.audio_transcription_driver, - ] - - def test_structure(self, config): - structure_1 = Agent( - config=config, - ) - - assert config.structure == structure_1 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 - - structure_2 = Agent( - config=config, - ) - assert config.structure == structure_2 - assert config._event_listener is not None - for driver in config.drivers: - if driver is not None: - assert config._event_listener in driver.event_listeners - assert len(driver.event_listeners) == 1 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..0be2f9758 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,12 @@ +import pytest + +from griptape.events import EventBus + + +@pytest.fixture(autouse=True) +def event_bus(): + EventBus.event_listeners = [] + + yield EventBus + + EventBus.event_listeners = [] diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 519e40f57..fc41837fd 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 7447b2c08..96b615a58 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -14,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -30,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -52,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -80,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index 14de15f2d..a77fb268e 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 2708b0a88..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.mixins import EventPublisherMixin +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(EventPublisherMixin, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) @@ -42,8 +42,7 @@ def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), Message) def test_run_with_stream(self): - pipeline = Pipeline() - result = MockPromptDriver(stream=True, event_listeners=pipeline.event_listeners).run(PromptStack(messages=[])) + result = MockPromptDriver(stream=True).run(PromptStack(messages=[])) assert isinstance(result, Message) assert result.value == "mock output" diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index 8af5dc827..ab448c7c1 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events.event_listener import EventListener +from griptape.events import EventBus, EventListener from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - driver.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py new file mode 100644 index 000000000..fd862913e --- /dev/null +++ b/tests/unit/events/test_event_bus.py @@ -0,0 +1,45 @@ +from unittest.mock import Mock + +from griptape.events import EventBus, EventListener +from tests.mocks.mock_event import MockEvent + + +class TestEventBus: + def test_add_event_listeners(self): + EventBus.add_event_listeners([EventListener(), EventListener()]) + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listeners(self): + listeners = [EventListener(), EventListener()] + EventBus.add_event_listeners(listeners) + EventBus.remove_event_listeners(listeners) + assert len(EventBus.event_listeners) == 0 + + def test_add_event_listener(self): + EventBus.add_event_listener(EventListener()) + EventBus.add_event_listener(EventListener()) + + assert len(EventBus.event_listeners) == 2 + + def test_remove_event_listener(self): + listener = EventListener() + EventBus.add_event_listener(listener) + EventBus.remove_event_listener(listener) + + assert len(EventBus.event_listeners) == 0 + + def test_remove_unknown_event_listener(self): + EventBus.remove_event_listener(EventListener()) + + def test_publish_event(self): + # Given + mock_handler = Mock() + mock_handler.return_value = None + EventBus.event_listeners = [EventListener(handler=mock_handler)] + mock_event = MockEvent() + + # When + EventBus.publish_event(mock_event) + + # Then + mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index b245c2be9..5601aef34 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,6 +4,7 @@ from griptape.events import ( CompletionChunkEvent, + EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -37,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - pipeline.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -58,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - pipeline.event_listeners = [ + EventBus.event_listeners = [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), EventListener(start_task_event_handler, event_types=[StartTaskEvent]), @@ -86,25 +87,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - pipeline.event_listeners = [] + EventBus.event_listeners = [] mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = pipeline.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = pipeline.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = pipeline.add_event_listener(EventListener(mock2)) + event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) - assert len(pipeline.event_listeners) == 4 + assert len(EventBus.event_listeners) == 4 - pipeline.remove_event_listener(event_listener_1) - pipeline.remove_event_listener(event_listener_3) - pipeline.remove_event_listener(event_listener_4) - pipeline.remove_event_listener(event_listener_5) - assert len(pipeline.event_listeners) == 0 + EventBus.remove_event_listener(event_listener_1) + EventBus.remove_event_listener(event_listener_3) + EventBus.remove_event_listener(event_listener_4) + EventBus.remove_event_listener(event_listener_5) + assert len(EventBus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/mixins/test_events_mixin.py b/tests/unit/mixins/test_events_mixin.py deleted file mode 100644 index 99f5541ba..000000000 --- a/tests/unit/mixins/test_events_mixin.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import Mock - -from griptape.events import EventListener -from griptape.mixins import EventPublisherMixin -from tests.mocks.mock_event import MockEvent - - -class TestEventsMixin: - def test_init(self): - assert EventPublisherMixin() - - def test_add_event_listeners(self): - mixin = EventPublisherMixin() - - mixin.add_event_listeners([EventListener(), EventListener()]) - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listeners(self): - mixin = EventPublisherMixin() - - listeners = [EventListener(), EventListener()] - mixin.add_event_listeners(listeners) - mixin.remove_event_listeners(listeners) - assert len(mixin.event_listeners) == 0 - - def test_add_event_listener(self): - mixin = EventPublisherMixin() - - mixin.add_event_listener(EventListener()) - mixin.add_event_listener(EventListener()) - - assert len(mixin.event_listeners) == 2 - - def test_remove_event_listener(self): - mixin = EventPublisherMixin() - - listener = EventListener() - mixin.add_event_listener(listener) - mixin.remove_event_listener(listener) - - assert len(mixin.event_listeners) == 0 - - def test_remove_unknown_event_listener(self): - mixin = EventPublisherMixin() - - mixin.remove_event_listener(EventListener()) - - def test_publish_event(self): - # Given - mock_handler = Mock() - mock_handler.return_value = None - mixin = EventPublisherMixin(event_listeners=[EventListener(handler=mock_handler)]) - mock_event = MockEvent() - - # When - mixin.publish_event(mock_event) - - # Then - mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 4f4b43d40..636515106 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,6 +3,7 @@ import pytest from griptape.artifacts import TextArtifact +from griptape.events import EventBus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -15,11 +16,11 @@ class TestBaseTask: @pytest.fixture() def task(self): + EventBus.event_listeners = [EventListener(handler=Mock())] agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], - event_listeners=[EventListener(handler=Mock())], ) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -117,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert task.structure.event_listeners[0].handler.call_count == 2 + assert EventBus.event_listeners[0].handler.call_count == 2 From 39f75c4d897560b22a4c7a81a3192ee47d30ea35 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 09:11:00 -0700 Subject: [PATCH 29/40] Update docs --- .../drivers/event-listener-drivers.md | 89 +++++++++++-------- docs/griptape-framework/misc/events.md | 14 +-- 2 files changed, 59 insertions(+), 44 deletions(-) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 73453afb6..db02cd77a 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,26 +14,27 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, + EventListener, EventBus ) from griptape.rules import Rule from griptape.structures import Agent -agent = Agent( - rules=[ - Rule( - value="You will be provided with a block of text, and your task is to extract a list of keywords from it." - ) - ], - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( - handler=lambda event: { # You can optionally use the handler to transform the event payload before sending it to the Driver - "event": event.to_dict(), - }, driver=AmazonSqsEventListenerDriver( queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], ), ), + ] +) + + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) ], ) @@ -83,23 +84,26 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, + EventListener, EventBus ) from griptape.rules import Rule from griptape.structures import Agent -agent = Agent( - rules=[ - Rule( - value="You will be provided with a block of text, and your task is to extract a list of keywords from it." - ) - ], - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( driver=AmazonSqsEventListenerDriver( queue_url=os.environ["AMAZON_SQS_QUEUE_URL"], ), ), + ] +) + +agent = Agent( + rules=[ + Rule( + value="You will be provided with a block of text, and your task is to extract a list of keywords from it." + ) ], ) @@ -128,10 +132,23 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.rules import Rule from griptape.structures import Agent +EventBus.add_event_listeners( + [ + EventListener( + event_types=[FinishStructureRunEvent], + driver=AwsIotCoreEventListenerDriver( + topic=os.environ["AWS_IOT_CORE_TOPIC"], + iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], + ), + ), + ] +) + agent = Agent( rules=[ Rule( @@ -143,15 +160,6 @@ agent = Agent( model="gpt-3.5-turbo", temperature=0.7 ) ), - event_listeners=[ - EventListener( - event_types=[FinishStructureRunEvent], - driver=AwsIotCoreEventListenerDriver( - topic=os.environ["AWS_IOT_CORE_TOPIC"], - iot_endpoint=os.environ["AWS_IOT_CORE_ENDPOINT"], - ), - ), - ], ) agent.run("I want to fly from Orlando to Boston") @@ -171,18 +179,19 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], # By default, GriptapeCloudEventListenerDriver uses the api key provided # in the GT_CLOUD_API_KEY environment variable. driver=GriptapeCloudEventListenerDriver(), ), - ], + ] ) agent.run( @@ -201,20 +210,23 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], driver=WebhookEventListenerDriver( webhook_url=os.environ["WEBHOOK_URL"], ), ), - ], + ] ) +agent = Agent() + agent.run("Analyze the pros and cons of remote work vs. office work") ``` ### Pusher @@ -229,12 +241,13 @@ import os from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, - FinishStructureRunEvent + FinishStructureRunEvent, + EventBus ) from griptape.structures import Agent -agent = Agent( - event_listeners=[ +EventBus.add_event_listeners( + [ EventListener( event_types=[FinishStructureRunEvent], driver=PusherEventListenerDriver( @@ -250,6 +263,8 @@ agent = Agent( ], ) +agent = Agent() + agent.run("Analyze the pros and cons of remote work vs. office work") ``` diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 187321dc6..23ebcdc2a 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -30,7 +30,7 @@ from griptape.events import ( def handler(event: BaseEvent): print(event.__class__) -EventBus.event_listeners=[ +EventBus.add_event_listeners([ EventListener( handler, event_types=[ @@ -42,7 +42,7 @@ EventBus.event_listeners=[ FinishPromptEvent, ], ) - ] + ]) agent = Agent() @@ -140,12 +140,12 @@ from griptape.drivers import OpenAiChatPromptDriver -EventBus.event_listeners = [ +EventBus.add_event_listeners([ EventListener( lambda e: print(e.token, end="", flush=True), event_types=[CompletionChunkEvent], ) -] +]) pipeline = Pipeline( config=OpenAiStructureConfig( @@ -194,12 +194,12 @@ from griptape.structures import Agent token_counter = utils.TokenCounter() -EventBus.event_listeners = [ +EventBus.add_event_listeners([ EventListener( lambda e: token_counter.add_tokens(e.token_count), event_types=[StartPromptEvent, FinishPromptEvent], ) -] +]) def count_tokens(e: BaseEvent): if isinstance(e, StartPromptEvent) or isinstance(e, FinishPromptEvent): @@ -248,7 +248,7 @@ from griptape.structures import Agent from griptape.events import BaseEvent, StartPromptEvent, EventListener, EventBus -EventBus.event_listeners = [EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])] +EventBus.add_event_listeners([EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])]) def handler(event: BaseEvent): if isinstance(event, StartPromptEvent): From 8a97313f9a224eb143975b9c670f90a874506141 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:03:17 -0700 Subject: [PATCH 30/40] Make event listeners private --- griptape/events/event_bus.py | 19 +++++++++++----- tests/unit/conftest.py | 4 ++-- tests/unit/events/test_event_bus.py | 2 +- tests/unit/events/test_event_listener.py | 28 +++++++++++++----------- tests/unit/tasks/test_base_task.py | 2 +- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 9239e66bd..6ffd65550 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -10,7 +10,11 @@ @define class _EventBus: - event_listeners: list[EventListener] = field(factory=list, kw_only=True) + _event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners") + + @property + def event_listeners(self) -> list[EventListener]: + return self._event_listeners def add_event_listeners(self, event_listeners: list[EventListener]) -> list[EventListener]: return [self.add_event_listener(event_listener) for event_listener in event_listeners] @@ -20,18 +24,21 @@ def remove_event_listeners(self, event_listeners: list[EventListener]) -> None: self.remove_event_listener(event_listener) def add_event_listener(self, event_listener: EventListener) -> EventListener: - if event_listener not in self.event_listeners: - self.event_listeners.append(event_listener) + if event_listener not in self._event_listeners: + self._event_listeners.append(event_listener) return event_listener def remove_event_listener(self, event_listener: EventListener) -> None: - if event_listener in self.event_listeners: - self.event_listeners.remove(event_listener) + if event_listener in self._event_listeners: + self._event_listeners.remove(event_listener) def publish_event(self, event: BaseEvent, *, flush: bool = False) -> None: - for event_listener in self.event_listeners: + for event_listener in self._event_listeners: event_listener.publish_event(event, flush=flush) + def clear_event_listeners(self) -> None: + self._event_listeners.clear() + EventBus = _EventBus() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0be2f9758..7a73b041f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,8 +5,8 @@ @pytest.fixture(autouse=True) def event_bus(): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() yield EventBus - EventBus.event_listeners = [] + EventBus.clear_event_listeners() diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index fd862913e..d237bb3b4 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -35,7 +35,7 @@ def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.event_listeners = [EventListener(handler=mock_handler)] + EventBus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 5601aef34..f3d9823d3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,17 +59,19 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.event_listeners = [ - EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), - EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), - EventListener(start_task_event_handler, event_types=[StartTaskEvent]), - EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), - EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), - EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), - EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), - EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), - EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), - ] + EventBus.add_event_listeners( + [ + EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), + EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), + EventListener(start_task_event_handler, event_types=[StartTaskEvent]), + EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), + EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), + EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), + EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), + EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), + EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), + ] + ) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -87,7 +89,7 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.event_listeners = [] + EventBus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 636515106..d6e4da8b6 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.event_listeners = [EventListener(handler=Mock())] + EventBus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), From 0e1d019d94d56c4ab59ef56f53c1f6fe5dc18678 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:17:34 -0700 Subject: [PATCH 31/40] Rename EventBus to event_bus --- CHANGELOG.md | 4 +-- .../drivers/event-listener-drivers.md | 24 +++++++-------- docs/griptape-framework/misc/events.md | 22 +++++++------- .../base_audio_transcription_driver.py | 6 ++-- .../base_image_generation_driver.py | 6 ++-- .../image_query/base_image_query_driver.py | 6 ++-- griptape/drivers/prompt/base_prompt_driver.py | 12 ++++---- .../base_text_to_speech_driver.py | 6 ++-- griptape/events/__init__.py | 4 +-- griptape/events/event_bus.py | 2 +- griptape/structures/structure.py | 6 ++-- griptape/tasks/actions_subtask.py | 6 ++-- griptape/tasks/base_task.py | 6 ++-- griptape/utils/stream.py | 6 ++-- tests/unit/conftest.py | 10 +++---- .../test_base_audio_transcription_driver.py | 4 +-- .../test_base_image_generation_driver.py | 10 +++---- .../test_base_image_query_driver.py | 4 +-- .../drivers/prompt/test_base_prompt_driver.py | 4 +-- .../test_base_audio_transcription_driver.py | 4 +-- tests/unit/events/test_event_bus.py | 30 +++++++++---------- tests/unit/events/test_event_listener.py | 30 +++++++++---------- tests/unit/tasks/test_base_task.py | 6 ++-- 23 files changed, 109 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 76f705ddb..9e016228c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. -- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. +- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. ### Changed -- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index db02cd77a..c3c92cfe1 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -14,12 +14,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -84,12 +84,12 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver from griptape.events import ( - EventListener, EventBus + EventListener, event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( @@ -132,12 +132,12 @@ from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriv from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.rules import Rule from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -179,11 +179,11 @@ from griptape.drivers import GriptapeCloudEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -210,11 +210,11 @@ from griptape.drivers import WebhookEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], @@ -242,11 +242,11 @@ from griptape.drivers import PusherEventListenerDriver from griptape.events import ( EventListener, FinishStructureRunEvent, - EventBus + event_bus ) from griptape.structures import Agent -EventBus.add_event_listeners( +event_bus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 23ebcdc2a..b3f4a77fd 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. +You can configure the global [event_bus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types @@ -23,14 +23,14 @@ from griptape.events import ( StartPromptEvent, FinishPromptEvent, EventListener, - EventBus + event_bus ) def handler(event: BaseEvent): print(event.__class__) -EventBus.add_event_listeners([ +event_bus.add_event_listeners([ EventListener( handler, event_types=[ @@ -69,7 +69,7 @@ Or listen to all events: ```python from griptape.structures import Agent -from griptape.events import BaseEvent, EventListener, EventBus +from griptape.events import BaseEvent, EventListener, event_bus @@ -80,7 +80,7 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -EventBus.event_listeners=[ +event_bus.event_listeners=[ EventListener(handler1), EventListener(handler2), ] @@ -131,7 +131,7 @@ Handler 2 None: - EventBus.publish_event(StartAudioTranscriptionEvent()) + event_bus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - EventBus.publish_event(FinishAudioTranscriptionEvent()) + event_bus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 8dfca5945..360fba8c9 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent +from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC) model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + event_bus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishImageGenerationEvent()) + event_bus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 28c571328..b1050b85c 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent +from griptape.events import FinishImageQueryEvent, StartImageQueryEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,12 +17,12 @@ class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - EventBus.publish_event( + event_bus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - EventBus.publish_event(FinishImageQueryEvent(result=result)) + event_bus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 94e46e75d..8044469b5 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,7 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent, event_bus from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + event_bus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -128,12 +128,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text)) + event_bus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content))) + event_bus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) + event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas result = self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index cb11cc498..c74264dc1 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -19,10 +19,10 @@ class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + event_bus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - EventBus.publish_event(FinishTextToSpeechEvent()) + event_bus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index b3e2f3a79..431927663 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,7 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent -from .event_bus import EventBus +from .event_bus import event_bus __all__ = [ "BaseEvent", @@ -49,5 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", - "EventBus", + "event_bus", ] diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index 6ffd65550..a956f7deb 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -41,4 +41,4 @@ def clear_event_listeners(self) -> None: self._event_listeners.clear() -EventBus = _EventBus() +event_bus = _EventBus() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index df7113c23..d68457ebc 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -28,7 +28,7 @@ VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage -from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent +from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory @@ -257,7 +257,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - EventBus.publish_event( + event_bus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -269,7 +269,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 07f49f52a..d600c80a5 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -10,7 +10,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent +from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -91,7 +91,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - EventBus.publish_event( + event_bus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -157,7 +157,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - EventBus.publish_event( + event_bus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 9a8361e6c..ade656f87 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent +from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -127,7 +127,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -139,7 +139,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - EventBus.publish_event( + event_bus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 4a7899b2a..fd64a0f52 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import CompletionChunkEvent, EventListener, FinishPromptEvent, FinishStructureRunEvent, event_bus if TYPE_CHECKING: from collections.abc import Iterator @@ -61,8 +61,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - EventBus.add_event_listener(stream_event_listener) + event_bus.add_event_listener(stream_event_listener) self.structure.run(*args) - EventBus.remove_event_listener(stream_event_listener) + event_bus.remove_event_listener(stream_event_listener) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7a73b041f..e462ede90 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,12 +1,12 @@ import pytest -from griptape.events import EventBus +from griptape.events import event_bus @pytest.fixture(autouse=True) -def event_bus(): - EventBus.clear_event_listeners() +def mock_event_bus(): + event_bus.clear_event_listeners() - yield EventBus + yield event_bus - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index fc41837fd..6fcab26e5 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 96b615a58..ab7b33ae8 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index a77fb268e..d8ba6b60f 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b6b0c600..52b7d5c0d 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _EventBus +from griptape.events.event_bus import _event_bus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_EventBus, "publish_event") + mock_publish_event = mocker.patch.object(_event_bus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index ab448c7c1..19493aa0f 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + event_bus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index d237bb3b4..7eb87036a 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -1,45 +1,45 @@ from unittest.mock import Mock -from griptape.events import EventBus, EventListener +from griptape.events import EventListener, event_bus from tests.mocks.mock_event import MockEvent class TestEventBus: def test_add_event_listeners(self): - EventBus.add_event_listeners([EventListener(), EventListener()]) - assert len(EventBus.event_listeners) == 2 + event_bus.add_event_listeners([EventListener(), EventListener()]) + assert len(event_bus.event_listeners) == 2 def test_remove_event_listeners(self): listeners = [EventListener(), EventListener()] - EventBus.add_event_listeners(listeners) - EventBus.remove_event_listeners(listeners) - assert len(EventBus.event_listeners) == 0 + event_bus.add_event_listeners(listeners) + event_bus.remove_event_listeners(listeners) + assert len(event_bus.event_listeners) == 0 def test_add_event_listener(self): - EventBus.add_event_listener(EventListener()) - EventBus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) + event_bus.add_event_listener(EventListener()) - assert len(EventBus.event_listeners) == 2 + assert len(event_bus.event_listeners) == 2 def test_remove_event_listener(self): listener = EventListener() - EventBus.add_event_listener(listener) - EventBus.remove_event_listener(listener) + event_bus.add_event_listener(listener) + event_bus.remove_event_listener(listener) - assert len(EventBus.event_listeners) == 0 + assert len(event_bus.event_listeners) == 0 def test_remove_unknown_event_listener(self): - EventBus.remove_event_listener(EventListener()) + event_bus.remove_event_listener(EventListener()) def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.add_event_listeners([EventListener(handler=mock_handler)]) + event_bus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When - EventBus.publish_event(mock_event) + event_bus.publish_event(mock_event) # Then mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index f3d9823d3..50763e0c3 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,7 +4,6 @@ from griptape.events import ( CompletionChunkEvent, - EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -14,6 +13,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + event_bus, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -38,7 +38,7 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) + event_bus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -59,7 +59,7 @@ def test_typed_listeners(self, pipeline): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - EventBus.add_event_listeners( + event_bus.add_event_listeners( [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), @@ -89,25 +89,25 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - EventBus.clear_event_listeners() + event_bus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = event_bus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = event_bus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) + event_listener_5 = event_bus.add_event_listener(EventListener(mock2)) - assert len(EventBus.event_listeners) == 4 + assert len(event_bus.event_listeners) == 4 - EventBus.remove_event_listener(event_listener_1) - EventBus.remove_event_listener(event_listener_3) - EventBus.remove_event_listener(event_listener_4) - EventBus.remove_event_listener(event_listener_5) - assert len(EventBus.event_listeners) == 0 + event_bus.remove_event_listener(event_listener_1) + event_bus.remove_event_listener(event_listener_3) + event_bus.remove_event_listener(event_listener_4) + event_bus.remove_event_listener(event_listener_5) + assert len(event_bus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index d6e4da8b6..aa402bb48 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.events import EventBus +from griptape.events import event_bus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -16,7 +16,7 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.add_event_listeners([EventListener(handler=Mock())]) + event_bus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( prompt_driver=MockPromptDriver(), embedding_driver=MockEmbeddingDriver(), @@ -118,4 +118,4 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert EventBus.event_listeners[0].handler.call_count == 2 + assert event_bus.event_listeners[0].handler.call_count == 2 From 4d491c234f9fed6abbd7fe00fcb8858b2ab49c26 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 10:29:50 -0700 Subject: [PATCH 32/40] Fix doc --- docs/griptape-framework/misc/events.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index b3f4a77fd..ebab3c460 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -80,10 +80,11 @@ def handler1(event: BaseEvent): def handler2(event: BaseEvent): print("Handler 2", event.__class__) -event_bus.event_listeners=[ +event_bus.add_event_listeners([ EventListener(handler1), EventListener(handler2), ] +) agent = Agent() From 863bcde112e3bb360950448cef4a115de406fbb9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 11:05:52 -0700 Subject: [PATCH 33/40] Fix test --- tests/unit/drivers/prompt/test_base_prompt_driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 52b7d5c0d..5b6b0c600 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.events.event_bus import _event_bus +from griptape.events.event_bus import _EventBus from griptape.structures import Pipeline from griptape.tasks import PromptTask, ToolkitTask from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver @@ -27,7 +27,7 @@ def test_run_via_pipeline_retries_failure(self): assert isinstance(pipeline.run().output_task.output, ErrorArtifact) def test_run_via_pipeline_publishes_events(self, mocker): - mock_publish_event = mocker.patch.object(_event_bus, "publish_event") + mock_publish_event = mocker.patch.object(_EventBus, "publish_event") driver = MockPromptDriver() pipeline = Pipeline(prompt_driver=driver) pipeline.add_task(PromptTask("test")) From 12efa677c7ffa413877d56d18fca5bd637c3e485 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 12:11:31 -0700 Subject: [PATCH 34/40] Fix doc --- docs/griptape-framework/misc/events.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index bfc8ee8ec..e70909074 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -84,10 +84,7 @@ event_bus.add_event_listeners([ EventListener(handler1), EventListener(handler2), ] - -agent = Agent() - -agent = Agent() +) agent = Agent() From 0d5ce93d6ceb5c6b6ba1ace203932e9179f2ea7b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 12:18:40 -0700 Subject: [PATCH 35/40] Update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04b592c4e..375773c3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `EventPublisherMixin`. - **BREAKING**: Removed `Workflow.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. - **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.config.config.drivers.embedding` instead. +- **BREAKING**: Removed `Structure.custom_logger` and `Structure.logger_level`, set these via `griptape.config.config.logger` instead. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. +- All Task and Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. ## [0.29.0] - 2024-07-30 From daa171031416d36eb846a92b803e8f4b40d09939 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 8 Aug 2024 12:47:46 -0700 Subject: [PATCH 36/40] Update changelog --- CHANGELOG.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 375773c3b..a528d9f81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,19 +11,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. -- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. -- Global config, `griptape.config.config`, for setting global configuration defaults. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. +- Global config, `griptape.config.config`, for setting global configuration defaults. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. -- **BREAKING**: Removed `Workflow.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. +- **BREAKING**: Removed `Pipeline.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. +- **BREAKING**: Removed `Pipeline.stream` and `Workflow.stream`. `Agent.stream` has not been removed. - **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.config.config.drivers.embedding` instead. - **BREAKING**: Removed `Structure.custom_logger` and `Structure.logger_level`, set these via `griptape.config.config.logger` instead. +- **BREAKING**: Removed `BaseStructureConfig.merge_config`. +- **BREAKING**: Renamed `StructureConfig` to `DriverConfig`, and renamed fields accordingly. +- Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. -- All Task and Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. ## [0.29.0] - 2024-07-30 From 11f9ac8da0701ae43fde15dcdec1ef62fc75fe83 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 9 Aug 2024 09:24:55 -0700 Subject: [PATCH 37/40] Fix type errors --- Makefile | 2 +- .../drivers/src/prompt_drivers_1.py | 5 +---- .../drivers/src/prompt_drivers_10.py | 7 ++----- .../drivers/src/prompt_drivers_11.py | 9 +++------ .../drivers/src/prompt_drivers_12.py | 9 +++------ .../drivers/src/prompt_drivers_13.py | 7 ++----- .../drivers/src/prompt_drivers_14.py | 9 +++------ .../drivers/src/prompt_drivers_3.py | 15 ++++++--------- .../drivers/src/prompt_drivers_4.py | 7 ++----- .../drivers/src/prompt_drivers_5.py | 13 +++++-------- .../drivers/src/prompt_drivers_6.py | 9 +++------ .../drivers/src/prompt_drivers_7.py | 9 +++------ .../drivers/src/prompt_drivers_8.py | 9 +++------ .../drivers/src/prompt_drivers_9.py | 7 ++----- docs/griptape-framework/misc/src/events_3.py | 12 +++++++----- 15 files changed, 46 insertions(+), 83 deletions(-) diff --git a/Makefile b/Makefile index f1db966f0..73175b7c5 100644 --- a/Makefile +++ b/Makefile @@ -68,7 +68,7 @@ check/lint: .PHONY: check/types check/types: - @poetry run pyright griptape/ docs/**/src/** + @poetry run pyright griptape $(shell find docs -type f -path "*/src/*") .PHONY: check/spell check/spell: diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_1.py b/docs/griptape-framework/drivers/src/prompt_drivers_1.py index 978435f2d..ab5273228 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_1.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_1.py @@ -1,12 +1,9 @@ -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), - ), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", rules=[Rule(value="Output only the sentiment.")], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_10.py b/docs/griptape-framework/drivers/src/prompt_drivers_10.py index 02f083570..04e2acb35 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_10.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_10.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import OllamaPromptDriver from griptape.structures import Agent from griptape.tools import Calculator agent = Agent( - config=StructureConfig( - prompt_driver=OllamaPromptDriver( - model="llama3.1", - ), + prompt_driver=OllamaPromptDriver( + model="llama3.1", ), tools=[Calculator()], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_11.py b/docs/griptape-framework/drivers/src/prompt_drivers_11.py index 1c81c4785..9e838473c 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_11.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_11.py @@ -1,16 +1,13 @@ import os -from griptape.config import StructureConfig from griptape.drivers import HuggingFaceHubPromptDriver from griptape.rules import Rule, Ruleset from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="HuggingFaceH4/zephyr-7b-beta", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - ) + prompt_driver=HuggingFaceHubPromptDriver( + model="HuggingFaceH4/zephyr-7b-beta", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ), rulesets=[ Ruleset( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_12.py b/docs/griptape-framework/drivers/src/prompt_drivers_12.py index d6f59f96e..d555c32c9 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_12.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_12.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import HuggingFaceHubPromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="http://127.0.0.1:8080", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - ), + prompt_driver=HuggingFaceHubPromptDriver( + model="http://127.0.0.1:8080", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ), ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_13.py b/docs/griptape-framework/drivers/src/prompt_drivers_13.py index e4fe5208c..d3ddd9093 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_13.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_13.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import HuggingFacePipelinePromptDriver from griptape.rules import Rule, Ruleset from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFacePipelinePromptDriver( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - ) + prompt_driver=HuggingFacePipelinePromptDriver( + model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", ), rulesets=[ Ruleset( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_14.py b/docs/griptape-framework/drivers/src/prompt_drivers_14.py index 85bd5216e..228a5f9b2 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_14.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_14.py @@ -1,17 +1,14 @@ import os -from griptape.config import StructureConfig from griptape.drivers import ( AmazonSageMakerJumpstartPromptDriver, ) from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AmazonSageMakerJumpstartPromptDriver( - endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], - model="meta-llama/Meta-Llama-3-8B-Instruct", - ) + prompt_driver=AmazonSageMakerJumpstartPromptDriver( + endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], + model="meta-llama/Meta-Llama-3-8B-Instruct", ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_3.py b/docs/griptape-framework/drivers/src/prompt_drivers_3.py index b92596aca..8e85ce887 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_3.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_3.py @@ -1,19 +1,16 @@ import os -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver( - api_key=os.environ["OPENAI_API_KEY"], - temperature=0.1, - model="gpt-4o", - response_format="json_object", - seed=42, - ) + prompt_driver=OpenAiChatPromptDriver( + api_key=os.environ["OPENAI_API_KEY"], + temperature=0.1, + model="gpt-4o", + response_format="json_object", + seed=42, ), input="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}", rules=[Rule(value='Write your output in json with a single key called "css_code".')], diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_4.py b/docs/griptape-framework/drivers/src/prompt_drivers_4.py index b024638b7..bcafb40de 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_4.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_4.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver( - base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True - ) + prompt_driver=OpenAiChatPromptDriver( + base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True ), rules=[Rule(value="You are a helpful coding assistant.")], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_5.py b/docs/griptape-framework/drivers/src/prompt_drivers_5.py index ffe9a4e0a..76301d8d9 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_5.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_5.py @@ -1,18 +1,15 @@ import os -from griptape.config import StructureConfig from griptape.drivers import AzureOpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AzureOpenAiChatPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="gpt-3.5-turbo", - azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - ) + prompt_driver=AzureOpenAiChatPromptDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + model="gpt-3.5-turbo", + azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], ), rules=[ Rule( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_6.py b/docs/griptape-framework/drivers/src/prompt_drivers_6.py index 2bd1b00fb..5e4d226a6 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_6.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_6.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import CoherePromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=CoherePromptDriver( - model="command-r", - api_key=os.environ["COHERE_API_KEY"], - ) + prompt_driver=CoherePromptDriver( + model="command-r", + api_key=os.environ["COHERE_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_7.py b/docs/griptape-framework/drivers/src/prompt_drivers_7.py index dd1c15370..23f3d0c35 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_7.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_7.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import AnthropicPromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AnthropicPromptDriver( - model="claude-3-opus-20240229", - api_key=os.environ["ANTHROPIC_API_KEY"], - ) + prompt_driver=AnthropicPromptDriver( + model="claude-3-opus-20240229", + api_key=os.environ["ANTHROPIC_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_8.py b/docs/griptape-framework/drivers/src/prompt_drivers_8.py index 1bbf2848c..b6a1c109e 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_8.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_8.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import GooglePromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=GooglePromptDriver( - model="gemini-pro", - api_key=os.environ["GOOGLE_API_KEY"], - ) + prompt_driver=GooglePromptDriver( + model="gemini-pro", + api_key=os.environ["GOOGLE_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_9.py b/docs/griptape-framework/drivers/src/prompt_drivers_9.py index cdd0db82d..992dbecd2 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_9.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_9.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import AmazonBedrockPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - ) + prompt_driver=AmazonBedrockPromptDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", ), rules=[ Rule( diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index a99a412eb..ab995e018 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -1,4 +1,6 @@ -from griptape.config import OpenAiDriverConfig +from typing import cast + +from griptape.config import OpenAiDriverConfig, config from griptape.drivers import OpenAiChatPromptDriver from griptape.events import CompletionChunkEvent, EventListener, event_bus from griptape.structures import Pipeline @@ -8,15 +10,15 @@ event_bus.add_event_listeners( [ EventListener( - lambda e: print(e.token, end="", flush=True), + lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True), event_types=[CompletionChunkEvent], ) ] ) -pipeline = Pipeline( - config=OpenAiDriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-4o", stream=True)), -) +config.drivers = OpenAiDriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-4o", stream=True)) + +pipeline = Pipeline() pipeline.add_tasks( ToolkitTask( From d0fab25e896c64ed78920b14d8a47f9528341019 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 12 Aug 2024 11:51:22 -0700 Subject: [PATCH 38/40] Fix utilities checking for stream --- griptape/utils/chat.py | 7 +++++-- griptape/utils/stream.py | 9 ++++++--- tests/unit/utils/test_stream.py | 14 +++++++++----- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 56b53c0ce..07fea92d8 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,9 +25,12 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - from griptape.config import config + from griptape.tasks.prompt_task import PromptTask - if config.drivers.prompt.stream: + streaming_tasks = [ + task for task in self.structure.tasks if isinstance(task, PromptTask) and task.prompt_driver.stream + ] + if streaming_tasks: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index c5545bc44..6da58b9e6 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -34,10 +34,13 @@ class Stream: @structure.validator # pyright: ignore[reportAttributeAccessIssue] def validate_structure(self, _: Attribute, structure: Structure) -> None: - from griptape.config import config + from griptape.tasks import PromptTask - if not config.drivers.prompt.stream: - raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") + streaming_tasks = [ + task for task in structure.tasks if isinstance(task, PromptTask) and task.prompt_driver.stream + ] + if not streaming_tasks: + raise ValueError("Structure does not have any streaming tasks, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index edd0258f2..caddbb1a3 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -2,19 +2,17 @@ import pytest -from griptape.config import config -from griptape.structures import Agent +from griptape.structures import Agent, Pipeline from griptape.utils import Stream class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - config.drivers.prompt.stream = request.param - return Agent() + return Agent(stream=request.param) def test_init(self, agent): - if config.drivers.prompt.stream: + if agent.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent @@ -29,3 +27,9 @@ def test_init(self, agent): else: with pytest.raises(ValueError): Stream(agent) + + def test_validate_structure_invalid(self): + pipeline = Pipeline(tasks=[]) + + with pytest.raises(ValueError): + Stream(pipeline) From 1ff7e88f293b1d7999bc612afb32db289f5d3b18 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 10:36:27 -0700 Subject: [PATCH 39/40] Clean up example --- docs/griptape-framework/misc/src/events_3.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index ab995e018..bae8b8224 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -1,6 +1,5 @@ from typing import cast -from griptape.config import OpenAiDriverConfig, config from griptape.drivers import OpenAiChatPromptDriver from griptape.events import CompletionChunkEvent, EventListener, event_bus from griptape.structures import Pipeline @@ -16,13 +15,11 @@ ] ) -config.drivers = OpenAiDriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-4o", stream=True)) - pipeline = Pipeline() - pipeline.add_tasks( ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True), tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], ) ) From d1798214c51e97fbb6c963ab2135e996da5e96ee Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 13 Aug 2024 10:37:50 -0700 Subject: [PATCH 40/40] Default stream to config value --- griptape/structures/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 59a865897..a046da6a9 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -23,7 +23,7 @@ class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) - stream: bool = field(default=False, kw_only=True) + stream: bool = field(default=Factory(lambda: config.drivers.prompt.stream), kw_only=True) prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)