From 9c7109dccc408a10a415c082a50973e392cface4 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 19 Aug 2024 11:14:49 -0700 Subject: [PATCH 1/3] Rename driver config fields --- .../src/multiple_agent_shared_memory_1.py | 2 +- .../drivers/src/embedding_drivers_10.py | 8 ++-- .../drivers/src/event_listener_drivers_4.py | 2 +- .../structures/src/config_7.py | 2 +- .../structures/src/task_memory_6.py | 4 +- .../official-tools/src/rest_api_tool_1.py | 2 +- .../drivers/amazon_bedrock_driver_config.py | 10 ++-- .../config/drivers/anthropic_driver_config.py | 8 ++-- .../drivers/azure_openai_driver_config.py | 10 ++-- griptape/config/drivers/base_driver_config.py | 48 ++++++++++--------- .../config/drivers/cohere_driver_config.py | 6 +-- griptape/config/drivers/driver_config.py | 18 +++---- .../config/drivers/google_driver_config.py | 6 +-- .../config/drivers/openai_driver_config.py | 14 +++--- .../audio/audio_transcription_engine.py | 2 +- .../engines/audio/text_to_speech_engine.py | 2 +- .../extraction/base_extraction_engine.py | 2 +- .../image/base_image_generation_engine.py | 2 +- .../engines/image_query/image_query_engine.py | 4 +- .../response/prompt_response_rag_module.py | 2 +- .../vector_store_retrieval_rag_module.py | 2 +- .../engines/summary/prompt_summary_engine.py | 2 +- .../structure/base_conversation_memory.py | 4 +- .../structure/summary_conversation_memory.py | 2 +- .../task/storage/text_artifact_storage.py | 2 +- griptape/structures/agent.py | 4 +- griptape/tasks/prompt_task.py | 2 +- griptape/tools/query/tool.py | 2 +- griptape/utils/chat.py | 2 +- tests/mocks/mock_driver_config.py | 10 ++-- .../test_amazon_bedrock_driver_config.py | 32 ++++++------- .../drivers/test_anthropic_driver_config.py | 16 +++---- .../test_azure_openai_driver_config.py | 16 +++---- .../drivers/test_cohere_driver_config.py | 16 +++---- .../unit/config/drivers/test_driver_config.py | 20 ++++---- .../drivers/test_google_driver_config.py | 16 +++---- .../drivers/test_openai_driver_config.py | 16 +++---- .../drivers/prompt/test_base_prompt_driver.py | 6 +-- .../test_local_structure_run_driver.py | 2 +- tests/unit/events/test_event_listener.py | 2 +- .../structure/test_conversation_memory.py | 8 ++-- tests/unit/structures/test_agent.py | 2 +- tests/unit/tasks/test_structure_run_task.py | 4 +- tests/unit/tasks/test_tool_task.py | 4 +- tests/unit/tasks/test_toolkit_task.py | 6 +-- tests/utils/structure_tester.py | 6 ++- 46 files changed, 184 insertions(+), 174 deletions(-) diff --git a/docs/examples/src/multiple_agent_shared_memory_1.py b/docs/examples/src/multiple_agent_shared_memory_1.py index 946a21190..c6c61f966 100644 --- a/docs/examples/src/multiple_agent_shared_memory_1.py +++ b/docs/examples/src/multiple_agent_shared_memory_1.py @@ -37,7 +37,7 @@ config.drivers = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, vector_store=mongo_driver, - embedding=embedding_driver, + embedding_driver=embedding_driver, ) loader = Agent( diff --git a/docs/griptape-framework/drivers/src/embedding_drivers_10.py b/docs/griptape-framework/drivers/src/embedding_drivers_10.py index dbbc659fb..75e2963d1 100644 --- a/docs/griptape-framework/drivers/src/embedding_drivers_10.py +++ b/docs/griptape-framework/drivers/src/embedding_drivers_10.py @@ -8,13 +8,13 @@ from griptape.tools import PromptSummaryTool, WebScraperTool config.drivers = DriverConfig( - prompt=OpenAiChatPromptDriver(model="gpt-4o"), - embedding=VoyageAiEmbeddingDriver(), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), + 
embedding_driver=VoyageAiEmbeddingDriver(), ) config.drivers = DriverConfig( - prompt=OpenAiChatPromptDriver(model="gpt-4o"), - embedding=VoyageAiEmbeddingDriver(), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), + embedding_driver=VoyageAiEmbeddingDriver(), ) agent = Agent( diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_4.py b/docs/griptape-framework/drivers/src/event_listener_drivers_4.py index 7a0957e63..d7f2b121e 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_4.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_4.py @@ -7,7 +7,7 @@ from griptape.rules import Rule from griptape.structures import Agent -config.drivers = DriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7)) +config.drivers = DriverConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7)) event_bus.add_event_listeners( [ EventListener( diff --git a/docs/griptape-framework/structures/src/config_7.py b/docs/griptape-framework/structures/src/config_7.py index 9f464b167..535f6e4d8 100644 --- a/docs/griptape-framework/structures/src/config_7.py +++ b/docs/griptape-framework/structures/src/config_7.py @@ -6,7 +6,7 @@ from griptape.structures import Agent config.drivers = DriverConfig( - prompt=AnthropicPromptDriver( + prompt_driver=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], ) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 3ce87f72d..3eb75a921 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -12,11 +12,11 @@ from griptape.tools import FileManagerTool, QueryTool, WebScraperTool config.drivers = OpenAiDriverConfig( - prompt=OpenAiChatPromptDriver(model="gpt-4"), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ) config.drivers = OpenAiDriverConfig( - prompt=OpenAiChatPromptDriver(model="gpt-4"), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ) vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) diff --git a/docs/griptape-tools/official-tools/src/rest_api_tool_1.py b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py index 4ef73dd9d..b181d4ec7 100644 --- a/docs/griptape-tools/official-tools/src/rest_api_tool_1.py +++ b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py @@ -9,7 +9,7 @@ from griptape.tools import RestApiTool config.drivers = DriverConfig( - prompt=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1), ) posts_client = RestApiTool( diff --git a/griptape/config/drivers/amazon_bedrock_driver_config.py b/griptape/config/drivers/amazon_bedrock_driver_config.py index db409353b..22d198167 100644 --- a/griptape/config/drivers/amazon_bedrock_driver_config.py +++ b/griptape/config/drivers/amazon_bedrock_driver_config.py @@ -30,15 +30,15 @@ class AmazonBedrockDriverConfig(DriverConfig): ) @lazy_property() - def prompt(self) -> AmazonBedrockPromptDriver: + def prompt_driver(self) -> AmazonBedrockPromptDriver: return AmazonBedrockPromptDriver(session=self.session, model="anthropic.claude-3-5-sonnet-20240620-v1:0") @lazy_property() - def embedding(self) -> AmazonBedrockTitanEmbeddingDriver: + def embedding_driver(self) -> AmazonBedrockTitanEmbeddingDriver: return AmazonBedrockTitanEmbeddingDriver(session=self.session, 
model="amazon.titan-embed-text-v1") @lazy_property() - def image_generation(self) -> AmazonBedrockImageGenerationDriver: + def image_generation_driver(self) -> AmazonBedrockImageGenerationDriver: return AmazonBedrockImageGenerationDriver( session=self.session, model="amazon.titan-image-generator-v1", @@ -46,7 +46,7 @@ def image_generation(self) -> AmazonBedrockImageGenerationDriver: ) @lazy_property() - def image_query(self) -> AmazonBedrockImageQueryDriver: + def image_query_driver(self) -> AmazonBedrockImageQueryDriver: return AmazonBedrockImageQueryDriver( session=self.session, model="anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -54,7 +54,7 @@ def image_query(self) -> AmazonBedrockImageQueryDriver: ) @lazy_property() - def vector_store(self) -> LocalVectorStoreDriver: + def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver( embedding_driver=AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1") ) diff --git a/griptape/config/drivers/anthropic_driver_config.py b/griptape/config/drivers/anthropic_driver_config.py index 399f13cdb..b036d85f4 100644 --- a/griptape/config/drivers/anthropic_driver_config.py +++ b/griptape/config/drivers/anthropic_driver_config.py @@ -13,17 +13,17 @@ @define class AnthropicDriverConfig(DriverConfig): @lazy_property() - def prompt(self) -> AnthropicPromptDriver: + def prompt_driver(self) -> AnthropicPromptDriver: return AnthropicPromptDriver(model="claude-3-5-sonnet-20240620") @lazy_property() - def embedding(self) -> VoyageAiEmbeddingDriver: + def embedding_driver(self) -> VoyageAiEmbeddingDriver: return VoyageAiEmbeddingDriver(model="voyage-large-2") @lazy_property() - def vector_store(self) -> LocalVectorStoreDriver: + def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")) @lazy_property() - def image_query(self) -> AnthropicImageQueryDriver: + def image_query_driver(self) -> AnthropicImageQueryDriver: return AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620") diff --git a/griptape/config/drivers/azure_openai_driver_config.py b/griptape/config/drivers/azure_openai_driver_config.py index 211a7d209..f27c8970c 100644 --- a/griptape/config/drivers/azure_openai_driver_config.py +++ b/griptape/config/drivers/azure_openai_driver_config.py @@ -41,7 +41,7 @@ class AzureOpenAiDriverConfig(DriverConfig): api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) @lazy_property() - def prompt(self) -> AzureOpenAiChatPromptDriver: + def prompt_driver(self) -> AzureOpenAiChatPromptDriver: return AzureOpenAiChatPromptDriver( model="gpt-4o", azure_endpoint=self.azure_endpoint, @@ -51,7 +51,7 @@ def prompt(self) -> AzureOpenAiChatPromptDriver: ) @lazy_property() - def embedding(self) -> AzureOpenAiEmbeddingDriver: + def embedding_driver(self) -> AzureOpenAiEmbeddingDriver: return AzureOpenAiEmbeddingDriver( model="text-embedding-3-small", azure_endpoint=self.azure_endpoint, @@ -61,7 +61,7 @@ def embedding(self) -> AzureOpenAiEmbeddingDriver: ) @lazy_property() - def image_generation(self) -> AzureOpenAiImageGenerationDriver: + def image_generation_driver(self) -> AzureOpenAiImageGenerationDriver: return AzureOpenAiImageGenerationDriver( model="dall-e-2", azure_endpoint=self.azure_endpoint, @@ -72,7 +72,7 @@ def image_generation(self) -> AzureOpenAiImageGenerationDriver: ) @lazy_property() - def image_query(self) -> AzureOpenAiImageQueryDriver: + def 
image_query_driver(self) -> AzureOpenAiImageQueryDriver: return AzureOpenAiImageQueryDriver( model="gpt-4o", azure_endpoint=self.azure_endpoint, @@ -82,7 +82,7 @@ def image_query(self) -> AzureOpenAiImageQueryDriver: ) @lazy_property() - def vector_store(self) -> LocalVectorStoreDriver: + def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver( embedding_driver=AzureOpenAiEmbeddingDriver( model="text-embedding-3-small", diff --git a/griptape/config/drivers/base_driver_config.py b/griptape/config/drivers/base_driver_config.py index 3a8672a26..d8555052f 100644 --- a/griptape/config/drivers/base_driver_config.py +++ b/griptape/config/drivers/base_driver_config.py @@ -23,57 +23,59 @@ @define class BaseDriverConfig(ABC, SerializableMixin): - _prompt: BasePromptDriver = field(kw_only=True, default=None, metadata={"serializable": True}, alias="prompt") - _image_generation: BaseImageGenerationDriver = field( - kw_only=True, default=None, metadata={"serializable": True}, alias="image_generation" + _prompt_driver: BasePromptDriver = field( + kw_only=True, default=None, metadata={"serializable": True}, alias="prompt_driver" ) - _image_query: BaseImageQueryDriver = field( - kw_only=True, default=None, metadata={"serializable": True}, alias="image_query" + _image_generation_driver: BaseImageGenerationDriver = field( + kw_only=True, default=None, metadata={"serializable": True}, alias="image_generation_driver" ) - _embedding: BaseEmbeddingDriver = field( - kw_only=True, default=None, metadata={"serializable": True}, alias="embedding" + _image_query_driver: BaseImageQueryDriver = field( + kw_only=True, default=None, metadata={"serializable": True}, alias="image_query_driver" ) - _vector_store: BaseVectorStoreDriver = field( - default=None, kw_only=True, metadata={"serializable": True}, alias="vector_store" + _embedding_driver: BaseEmbeddingDriver = field( + kw_only=True, default=None, metadata={"serializable": True}, alias="embedding_driver" ) - _conversation_memory: Optional[BaseConversationMemoryDriver] = field( - default=None, kw_only=True, metadata={"serializable": True}, alias="conversation_memory" + _vector_store_driver: BaseVectorStoreDriver = field( + default=None, kw_only=True, metadata={"serializable": True}, alias="vector_store_driver" ) - _text_to_speech: BaseTextToSpeechDriver = field( - default=None, kw_only=True, metadata={"serializable": True}, alias="text_to_speech" + _conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + default=None, kw_only=True, metadata={"serializable": True}, alias="conversation_memory_driver" ) - _audio_transcription: BaseAudioTranscriptionDriver = field( - default=None, kw_only=True, metadata={"serializable": True}, alias="audio_transcription" + _text_to_speech_driver: BaseTextToSpeechDriver = field( + default=None, kw_only=True, metadata={"serializable": True}, alias="text_to_speech_driver" + ) + _audio_transcription_driver: BaseAudioTranscriptionDriver = field( + default=None, kw_only=True, metadata={"serializable": True}, alias="audio_transcription_driver" ) @lazy_property() @abstractmethod - def prompt(self) -> BasePromptDriver: ... + def prompt_driver(self) -> BasePromptDriver: ... @lazy_property() @abstractmethod - def image_generation(self) -> BaseImageGenerationDriver: ... + def image_generation_driver(self) -> BaseImageGenerationDriver: ... @lazy_property() @abstractmethod - def image_query(self) -> BaseImageQueryDriver: ... + def image_query_driver(self) -> BaseImageQueryDriver: ... 
@lazy_property() @abstractmethod - def embedding(self) -> BaseEmbeddingDriver: ... + def embedding_driver(self) -> BaseEmbeddingDriver: ... @lazy_property() @abstractmethod - def vector_store(self) -> BaseVectorStoreDriver: ... + def vector_store_driver(self) -> BaseVectorStoreDriver: ... @lazy_property() @abstractmethod - def conversation_memory(self) -> Optional[BaseConversationMemoryDriver]: ... + def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: ... @lazy_property() @abstractmethod - def text_to_speech(self) -> BaseTextToSpeechDriver: ... + def text_to_speech_driver(self) -> BaseTextToSpeechDriver: ... @lazy_property() @abstractmethod - def audio_transcription(self) -> BaseAudioTranscriptionDriver: ... + def audio_transcription_driver(self) -> BaseAudioTranscriptionDriver: ... diff --git a/griptape/config/drivers/cohere_driver_config.py b/griptape/config/drivers/cohere_driver_config.py index ae3b6c184..25dc833e5 100644 --- a/griptape/config/drivers/cohere_driver_config.py +++ b/griptape/config/drivers/cohere_driver_config.py @@ -14,11 +14,11 @@ class CohereDriverConfig(DriverConfig): api_key: str = field(metadata={"serializable": False}, kw_only=True) @lazy_property() - def prompt(self) -> CoherePromptDriver: + def prompt_driver(self) -> CoherePromptDriver: return CoherePromptDriver(model="command-r", api_key=self.api_key) @lazy_property() - def embedding(self) -> CohereEmbeddingDriver: + def embedding_driver(self) -> CohereEmbeddingDriver: return CohereEmbeddingDriver( model="embed-english-v3.0", api_key=self.api_key, @@ -26,7 +26,7 @@ def embedding(self) -> CohereEmbeddingDriver: ) @lazy_property() - def vector_store(self) -> LocalVectorStoreDriver: + def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver( embedding_driver=CohereEmbeddingDriver( model="embed-english-v3.0", diff --git a/griptape/config/drivers/driver_config.py b/griptape/config/drivers/driver_config.py index b102c4465..16cb9a535 100644 --- a/griptape/config/drivers/driver_config.py +++ b/griptape/config/drivers/driver_config.py @@ -32,33 +32,33 @@ @define class DriverConfig(BaseDriverConfig): @lazy_property() - def prompt(self) -> BasePromptDriver: + def prompt_driver(self) -> BasePromptDriver: return DummyPromptDriver() @lazy_property() - def image_generation(self) -> BaseImageGenerationDriver: + def image_generation_driver(self) -> BaseImageGenerationDriver: return DummyImageGenerationDriver() @lazy_property() - def image_query(self) -> BaseImageQueryDriver: + def image_query_driver(self) -> BaseImageQueryDriver: return DummyImageQueryDriver() @lazy_property() - def embedding(self) -> BaseEmbeddingDriver: + def embedding_driver(self) -> BaseEmbeddingDriver: return DummyEmbeddingDriver() @lazy_property() - def vector_store(self) -> BaseVectorStoreDriver: - return DummyVectorStoreDriver(embedding_driver=self.embedding) + def vector_store_driver(self) -> BaseVectorStoreDriver: + return DummyVectorStoreDriver(embedding_driver=self.embedding_driver) @lazy_property() - def conversation_memory(self) -> Optional[BaseConversationMemoryDriver]: + def conversation_memory_driver(self) -> Optional[BaseConversationMemoryDriver]: return None @lazy_property() - def text_to_speech(self) -> BaseTextToSpeechDriver: + def text_to_speech_driver(self) -> BaseTextToSpeechDriver: return DummyTextToSpeechDriver() @lazy_property() - def audio_transcription(self) -> BaseAudioTranscriptionDriver: + def audio_transcription_driver(self) -> BaseAudioTranscriptionDriver: return 
DummyAudioTranscriptionDriver() diff --git a/griptape/config/drivers/google_driver_config.py b/griptape/config/drivers/google_driver_config.py index 5cff7ef6d..0ab72e6bb 100644 --- a/griptape/config/drivers/google_driver_config.py +++ b/griptape/config/drivers/google_driver_config.py @@ -12,13 +12,13 @@ @define class GoogleDriverConfig(DriverConfig): @lazy_property() - def prompt(self) -> GooglePromptDriver: + def prompt_driver(self) -> GooglePromptDriver: return GooglePromptDriver(model="gemini-1.5-pro") @lazy_property() - def embedding(self) -> GoogleEmbeddingDriver: + def embedding_driver(self) -> GoogleEmbeddingDriver: return GoogleEmbeddingDriver(model="models/embedding-001") @lazy_property() - def vector_store(self) -> LocalVectorStoreDriver: + def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")) diff --git a/griptape/config/drivers/openai_driver_config.py b/griptape/config/drivers/openai_driver_config.py index a16f79ed8..49cf60206 100644 --- a/griptape/config/drivers/openai_driver_config.py +++ b/griptape/config/drivers/openai_driver_config.py @@ -16,29 +16,29 @@ @define class OpenAiDriverConfig(DriverConfig): @lazy_property() - def prompt(self) -> OpenAiChatPromptDriver: + def prompt_driver(self) -> OpenAiChatPromptDriver: return OpenAiChatPromptDriver(model="gpt-4o") @lazy_property() - def image_generation(self) -> OpenAiImageGenerationDriver: + def image_generation_driver(self) -> OpenAiImageGenerationDriver: return OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512") @lazy_property() - def image_query(self) -> OpenAiImageQueryDriver: + def image_query_driver(self) -> OpenAiImageQueryDriver: return OpenAiImageQueryDriver(model="gpt-4o") @lazy_property() - def embedding(self) -> OpenAiEmbeddingDriver: + def embedding_driver(self) -> OpenAiEmbeddingDriver: return OpenAiEmbeddingDriver(model="text-embedding-3-small") @lazy_property() - def vector_store(self) -> LocalVectorStoreDriver: + def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")) @lazy_property() - def text_to_speech(self) -> OpenAiTextToSpeechDriver: + def text_to_speech_driver(self) -> OpenAiTextToSpeechDriver: return OpenAiTextToSpeechDriver(model="tts") @lazy_property() - def audio_transcription(self) -> OpenAiAudioTranscriptionDriver: + def audio_transcription_driver(self) -> OpenAiAudioTranscriptionDriver: return OpenAiAudioTranscriptionDriver(model="whisper-1") diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index cad8287d5..6ab8abc3a 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -8,7 +8,7 @@ @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: config.drivers.audio_transcription), kw_only=True + default=Factory(lambda: config.drivers.audio_transcription_driver), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index aad45a10a..4d058f910 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: 
BaseTextToSpeechDriver = field( - default=Factory(lambda: config.drivers.text_to_speech), kw_only=True + default=Factory(lambda: config.drivers.text_to_speech_driver), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 8f61bb764..cb152cec7 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -18,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 9bec68b91..583d5bc92 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: config.drivers.image_generation) + kw_only=True, default=Factory(lambda: config.drivers.image_generation_driver) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index f2bd99544..393da833a 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -13,7 +13,9 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: config.drivers.image_query), kw_only=True) + image_query_driver: BaseImageQueryDriver = field( + default=Factory(lambda: config.drivers.image_query_driver), kw_only=True + ) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 92e611223..9253122c0 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -18,7 +18,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule, RuleMixin): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt)) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver)) answer_token_offset: int = field(default=400) metadata: Optional[str] = field(default=None) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 6ce235fa5..4fb2bcbc8 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -18,7 +18,7 @@ @define(kw_only=True) class 
VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store_driver)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 065677e1b..3808416f9 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index d6e3549af..8082b3bf4 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: config.drivers.conversation_memory), kw_only=True + default=Factory(lambda: config.drivers.conversation_memory_driver), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = config.drivers.prompt + prompt_driver = config.drivers.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 4263e61c8..ed2807216 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: config.drivers.prompt)) + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: config.drivers.prompt_driver)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py 
b/griptape/memory/task/storage/text_artifact_storage.py index 623c176ea..dd33b0023 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -14,7 +14,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store_driver)) def can_store(self, artifact: BaseArtifact) -> bool: return isinstance(artifact, TextArtifact) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index a046da6a9..549e7c665 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -23,8 +23,8 @@ class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) - stream: bool = field(default=Factory(lambda: config.drivers.prompt.stream), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) + stream: bool = field(default=Factory(lambda: config.drivers.prompt_driver.stream), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index a8038832d..0f3acc8c6 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -20,7 +20,7 @@ @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index 3ecc63bca..eed4c9522 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -25,7 +25,7 @@ class QueryTool(BaseTool, RuleMixin): lambda self: RagEngine( response_stage=ResponseRagStage( response_modules=[ - PromptResponseRagModule(prompt_driver=config.drivers.prompt, rulesets=self.rulesets) + PromptResponseRagModule(prompt_driver=config.drivers.prompt_driver, rulesets=self.rulesets) ], ), ), diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index a5cbce13d..f34a97504 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -53,7 +53,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if config.drivers.prompt.stream: + if config.drivers.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/tests/mocks/mock_driver_config.py b/tests/mocks/mock_driver_config.py index eeab99abc..b038fe920 100644 --- a/tests/mocks/mock_driver_config.py +++ b/tests/mocks/mock_driver_config.py @@ -12,21 +12,21 @@ @define class MockDriverConfig(DriverConfig): @lazy_property() - def prompt(self) -> MockPromptDriver: + def prompt_driver(self) -> MockPromptDriver: return 
MockPromptDriver() @lazy_property() - def image_generation(self) -> MockImageGenerationDriver: + def image_generation_driver(self) -> MockImageGenerationDriver: return MockImageGenerationDriver() @lazy_property() - def image_query(self) -> MockImageQueryDriver: + def image_query_driver(self) -> MockImageQueryDriver: return MockImageQueryDriver() @lazy_property() - def embedding(self) -> MockEmbeddingDriver: + def embedding_driver(self) -> MockEmbeddingDriver: return MockEmbeddingDriver() @lazy_property() - def vector_store(self) -> LocalVectorStoreDriver: + def vector_store_driver(self) -> LocalVectorStoreDriver: return LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) diff --git a/tests/unit/config/drivers/test_amazon_bedrock_driver_config.py b/tests/unit/config/drivers/test_amazon_bedrock_driver_config.py index a76eeb278..e30444332 100644 --- a/tests/unit/config/drivers/test_amazon_bedrock_driver_config.py +++ b/tests/unit/config/drivers/test_amazon_bedrock_driver_config.py @@ -25,9 +25,9 @@ def config_with_values(self): def test_to_dict(self, config): assert config.to_dict() == { - "conversation_memory": None, - "embedding": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, - "image_generation": { + "conversation_memory_driver": None, + "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation_driver": { "image_generation_model_driver": { "cfg_scale": 7, "outpainting_mode": "PRECISE", @@ -40,13 +40,13 @@ def test_to_dict(self, config): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query": { + "image_query_driver": { "type": "AmazonBedrockImageQueryDriver", "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, - "prompt": { + "prompt_driver": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, @@ -55,7 +55,7 @@ def test_to_dict(self, config): "tool_choice": {"auto": {}}, "use_native_tools": True, }, - "vector_store": { + "vector_store_driver": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", @@ -63,8 +63,8 @@ def test_to_dict(self, config): "type": "LocalVectorStoreDriver", }, "type": "AmazonBedrockDriverConfig", - "text_to_speech": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): @@ -77,9 +77,9 @@ def test_from_dict_with_values(self, config_with_values): def test_to_dict_with_values(self, config_with_values): assert config_with_values.to_dict() == { - "conversation_memory": None, - "embedding": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, - "image_generation": { + "conversation_memory_driver": None, + "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation_driver": { "image_generation_model_driver": { "cfg_scale": 7, "outpainting_mode": "PRECISE", @@ -92,13 +92,13 @@ def test_to_dict_with_values(self, config_with_values): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query": { + "image_query_driver": { "type": "AmazonBedrockImageQueryDriver", "model": 
"anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, - "prompt": { + "prompt_driver": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, @@ -107,7 +107,7 @@ def test_to_dict_with_values(self, config_with_values): "tool_choice": {"auto": {}}, "use_native_tools": True, }, - "vector_store": { + "vector_store_driver": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", @@ -115,7 +115,7 @@ def test_to_dict_with_values(self, config_with_values): "type": "LocalVectorStoreDriver", }, "type": "AmazonBedrockDriverConfig", - "text_to_speech": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } assert config_with_values.session.region_name == "region-value" diff --git a/tests/unit/config/drivers/test_anthropic_driver_config.py b/tests/unit/config/drivers/test_anthropic_driver_config.py index a496c47b7..770a04b9f 100644 --- a/tests/unit/config/drivers/test_anthropic_driver_config.py +++ b/tests/unit/config/drivers/test_anthropic_driver_config.py @@ -16,7 +16,7 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "AnthropicDriverConfig", - "prompt": { + "prompt_driver": { "type": "AnthropicPromptDriver", "temperature": 0.1, "max_tokens": 1000, @@ -26,18 +26,18 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, }, - "image_generation": {"type": "DummyImageGenerationDriver"}, - "image_query": { + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": { "type": "AnthropicImageQueryDriver", "model": "claude-3-5-sonnet-20240620", "max_tokens": 256, }, - "embedding": { + "embedding_driver": { "type": "VoyageAiEmbeddingDriver", "model": "voyage-large-2", "input_type": "document", }, - "vector_store": { + "vector_store_driver": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "VoyageAiEmbeddingDriver", @@ -45,9 +45,9 @@ def test_to_dict(self, config): "input_type": "document", }, }, - "conversation_memory": None, - "text_to_speech": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/drivers/test_azure_openai_driver_config.py b/tests/unit/config/drivers/test_azure_openai_driver_config.py index ef418e097..dfc69ce46 100644 --- a/tests/unit/config/drivers/test_azure_openai_driver_config.py +++ b/tests/unit/config/drivers/test_azure_openai_driver_config.py @@ -20,7 +20,7 @@ def test_to_dict(self, config): assert config.to_dict() == { "type": "AzureOpenAiDriverConfig", "azure_endpoint": "http://localhost:8080", - "prompt": { + "prompt_driver": { "type": "AzureOpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", @@ -36,8 +36,8 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory": None, - "embedding": { + "conversation_memory_driver": None, + "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", "api_version": "2023-05-15", @@ -46,7 +46,7 @@ def test_to_dict(self, 
config): "organization": None, "type": "AzureOpenAiEmbeddingDriver", }, - "image_generation": { + "image_generation_driver": { "api_version": "2024-02-01", "base_url": None, "image_size": "512x512", @@ -59,7 +59,7 @@ def test_to_dict(self, config): "style": None, "type": "AzureOpenAiImageGenerationDriver", }, - "image_query": { + "image_query_driver": { "base_url": None, "image_quality": "auto", "max_tokens": 256, @@ -70,7 +70,7 @@ def test_to_dict(self, config): "organization": None, "type": "AzureOpenAiImageQueryDriver", }, - "vector_store": { + "vector_store_driver": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", @@ -82,6 +82,6 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "text_to_speech": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/config/drivers/test_cohere_driver_config.py b/tests/unit/config/drivers/test_cohere_driver_config.py index 5a75c98cd..982733dd6 100644 --- a/tests/unit/config/drivers/test_cohere_driver_config.py +++ b/tests/unit/config/drivers/test_cohere_driver_config.py @@ -11,12 +11,12 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "CohereDriverConfig", - "image_generation": {"type": "DummyImageGenerationDriver"}, - "image_query": {"type": "DummyImageQueryDriver"}, - "conversation_memory": None, - "text_to_speech": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, - "prompt": { + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "prompt_driver": { "type": "CoherePromptDriver", "temperature": 0.1, "max_tokens": None, @@ -25,12 +25,12 @@ def test_to_dict(self, config): "force_single_step": False, "use_native_tools": True, }, - "embedding": { + "embedding_driver": { "type": "CohereEmbeddingDriver", "model": "embed-english-v3.0", "input_type": "search_document", }, - "vector_store": { + "vector_store_driver": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "CohereEmbeddingDriver", diff --git a/tests/unit/config/drivers/test_driver_config.py b/tests/unit/config/drivers/test_driver_config.py index 71220646f..4c11a24cd 100644 --- a/tests/unit/config/drivers/test_driver_config.py +++ b/tests/unit/config/drivers/test_driver_config.py @@ -11,29 +11,29 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "DriverConfig", - "prompt": { + "prompt_driver": { "type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False, "use_native_tools": False, }, - "conversation_memory": None, - "embedding": {"type": "DummyEmbeddingDriver"}, - "image_generation": {"type": "DummyImageGenerationDriver"}, - "image_query": {"type": "DummyImageQueryDriver"}, - "vector_store": { + "conversation_memory_driver": None, + "embedding_driver": {"type": "DummyEmbeddingDriver"}, + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "vector_store_driver": { "embedding_driver": {"type": "DummyEmbeddingDriver"}, "type": 
"DummyVectorStoreDriver", }, - "text_to_speech": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): assert DriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() def test_dot_update(self, config): - config.prompt.max_tokens = 10 + config.prompt_driver.max_tokens = 10 - assert config.prompt.max_tokens == 10 + assert config.prompt_driver.max_tokens == 10 diff --git a/tests/unit/config/drivers/test_google_driver_config.py b/tests/unit/config/drivers/test_google_driver_config.py index 3a16173b5..e16f63eb3 100644 --- a/tests/unit/config/drivers/test_google_driver_config.py +++ b/tests/unit/config/drivers/test_google_driver_config.py @@ -15,7 +15,7 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "GoogleDriverConfig", - "prompt": { + "prompt_driver": { "type": "GooglePromptDriver", "temperature": 0.1, "max_tokens": None, @@ -26,15 +26,15 @@ def test_to_dict(self, config): "tool_choice": "auto", "use_native_tools": True, }, - "image_generation": {"type": "DummyImageGenerationDriver"}, - "image_query": {"type": "DummyImageQueryDriver"}, - "embedding": { + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "embedding_driver": { "type": "GoogleEmbeddingDriver", "model": "models/embedding-001", "task_type": "retrieval_document", "title": None, }, - "vector_store": { + "vector_store_driver": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "GoogleEmbeddingDriver", @@ -43,9 +43,9 @@ def test_to_dict(self, config): "title": None, }, }, - "conversation_memory": None, - "text_to_speech": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/drivers/test_openai_driver_config.py b/tests/unit/config/drivers/test_openai_driver_config.py index 860f70518..5c560a7f7 100644 --- a/tests/unit/config/drivers/test_openai_driver_config.py +++ b/tests/unit/config/drivers/test_openai_driver_config.py @@ -15,7 +15,7 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { "type": "OpenAiDriverConfig", - "prompt": { + "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", @@ -28,14 +28,14 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory": None, - "embedding": { + "conversation_memory_driver": None, + "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", "organization": None, "type": "OpenAiEmbeddingDriver", }, - "image_generation": { + "image_generation_driver": { "api_version": None, "base_url": None, "image_size": "512x512", @@ -46,7 +46,7 @@ def test_to_dict(self, config): "style": None, "type": "OpenAiImageGenerationDriver", }, - "image_query": { + "image_query_driver": { "api_version": None, "base_url": None, "image_quality": "auto", @@ -55,7 +55,7 @@ def test_to_dict(self, config): "organization": None, "type": "OpenAiImageQueryDriver", }, - "vector_store": { + "vector_store_driver": { "embedding_driver": { "base_url": None, 
"model": "text-embedding-3-small", @@ -64,7 +64,7 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "text_to_speech": { + "text_to_speech_driver": { "type": "OpenAiTextToSpeechDriver", "api_version": None, "base_url": None, @@ -73,7 +73,7 @@ def test_to_dict(self, config): "organization": None, "voice": "alloy", }, - "audio_transcription": { + "audio_transcription_driver": { "type": "OpenAiAudioTranscriptionDriver", "api_version": None, "base_url": None, diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index c30acdec4..c575a01d3 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -11,7 +11,7 @@ class TestBasePromptDriver: def test_run_via_pipeline_retries_success(self, mock_config): - mock_config.drivers.prompt = MockPromptDriver(max_attempts=2) + mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=2) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -19,7 +19,7 @@ def test_run_via_pipeline_retries_success(self, mock_config): assert isinstance(pipeline.run().output_task.output, TextArtifact) def test_run_via_pipeline_retries_failure(self, mock_config): - mock_config.drivers.prompt = MockFailingPromptDriver(max_failures=2, max_attempts=1) + mock_config.drivers.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -46,7 +46,7 @@ def test_run_with_stream(self): assert result.value == "mock output" def test_run_with_tools(self, mock_config): - mock_config.drivers.prompt = MockPromptDriver(max_attempts=1, use_native_tools=True) + mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index c2bb45208..b2e9c069b 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -20,7 +20,7 @@ def test_run(self): def test_run_with_env(self, mock_config): pipeline = Pipeline() - mock_config.drivers.prompt = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 4e21fa392..e66b8816b 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -26,7 +26,7 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - mock_config.drivers.prompt = MockPromptDriver(stream=True) + mock_config.drivers.prompt_driver = MockPromptDriver(stream=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) pipeline = Pipeline() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index f0e4b0af3..06e54e6c4 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -97,7 +97,7 @@ def test_add_to_prompt_stack_autopruing_disabled(self): def 
test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -119,7 +119,9 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): assert len(prompt_stack.messages) == 3 # No memory is pruned. - mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000)) + mock_config.drivers.prompt_driver = MockPromptDriver( + tokenizer=MockTokenizer(model="foo", max_input_tokens=1000) + ) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -143,7 +145,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) agent = Agent() memory = ConversationMemory( autoprune=True, diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index ef5faeff1..36a73db74 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -222,7 +222,7 @@ def test_task_memory_defaults(self, mock_config): storage = list(agent.task_memory.artifact_storages.values())[0] assert isinstance(storage, TextArtifactStorage) - assert storage.vector_store_driver.embedding_driver == mock_config.drivers.embedding + assert storage.vector_store_driver.embedding_driver == mock_config.drivers.embedding_driver def finished_tasks(self): task = PromptTask("test prompt") diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index d18d75d75..2c0dc1b28 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -6,9 +6,9 @@ class TestStructureRunTask: def test_run(self, mock_config): - mock_config.drivers.prompt = MockPromptDriver(mock_output="agent mock output") + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="agent mock output") agent = Agent() - mock_config.drivers.prompt = MockPromptDriver(mock_output="pipeline mock output") + mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 18521632e..70ab05e12 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -168,7 +168,9 @@ class TestToolTask: def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - mock_config.drivers.prompt = MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}") + mock_config.drivers.prompt_driver = MockPromptDriver( + mock_output=f"```python foo bar\n{json.dumps(output_dict)}" + ) return Agent() diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index c1b91b1ed..15f5a59b1 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ 
b/tests/unit/tasks/test_toolkit_task.py @@ -171,7 +171,7 @@ def test_init(self): def test_run(self, mock_config): output = """Answer: done""" - mock_config.drivers.prompt.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent() @@ -186,7 +186,7 @@ def test_run(self, mock_config): def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' - mock_config.drivers.prompt.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() @@ -200,7 +200,7 @@ def test_run_max_subtasks(self, mock_config): def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" - mock_config.drivers.prompt.mock_output = output + mock_config.drivers.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index d87fc095e..317fc0e84 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -25,7 +25,9 @@ def get_enabled_prompt_drivers(prompt_drivers_options) -> list[BasePromptDriver]: return [ - prompt_driver_option.prompt for prompt_driver_option in prompt_drivers_options if prompt_driver_option.enabled + prompt_driver_option.prompt_driver + for prompt_driver_option in prompt_drivers_options + if prompt_driver_option.enabled ] @@ -228,7 +230,7 @@ def prompt_driver_id_fn(cls, prompt_driver) -> str: def verify_structure_output(self, structure) -> dict: from griptape.config import config - config.drivers.prompt = AzureOpenAiChatPromptDriver( + config.drivers.prompt_driver = AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-4o", azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], From 83f5e806ebd91fa237a12ba567550c17f5a22d03 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 19 Aug 2024 11:19:44 -0700 Subject: [PATCH 2/3] Rename Config fields --- .../src/multiple_agent_shared_memory_1.py | 4 +-- docs/examples/src/talk_to_a_video_1.py | 2 +- .../drivers/src/embedding_drivers_10.py | 4 +-- .../drivers/src/event_listener_drivers_4.py | 2 +- .../structures/src/config_1.py | 2 +- .../structures/src/config_2.py | 2 +- .../structures/src/config_3.py | 2 +- .../structures/src/config_4.py | 2 +- .../structures/src/config_5.py | 2 +- .../structures/src/config_6.py | 2 +- .../structures/src/config_7.py | 2 +- .../structures/src/config_8.py | 2 +- .../structures/src/config_logging.py | 4 +-- .../structures/src/task_memory_6.py | 4 +-- .../official-tools/src/rest_api_tool_1.py | 2 +- griptape/config/base_config.py | 8 ++--- griptape/config/config.py | 34 +++++++++---------- .../audio/audio_transcription_engine.py | 2 +- .../engines/audio/text_to_speech_engine.py | 2 +- .../extraction/base_extraction_engine.py | 2 +- .../image/base_image_generation_engine.py | 2 +- .../engines/image_query/image_query_engine.py | 2 +- .../response/prompt_response_rag_module.py | 2 +- .../vector_store_retrieval_rag_module.py | 4 ++- .../engines/summary/prompt_summary_engine.py | 2 +- .../structure/base_conversation_memory.py | 4 +-- .../structure/summary_conversation_memory.py | 2 +- .../task/storage/text_artifact_storage.py | 4 ++- griptape/structures/agent.py | 4 +-- 
griptape/tasks/actions_subtask.py | 2 +- griptape/tasks/base_audio_generation_task.py | 2 +- griptape/tasks/base_audio_input_task.py | 2 +- griptape/tasks/base_image_generation_task.py | 2 +- griptape/tasks/base_multi_text_input_task.py | 2 +- griptape/tasks/base_task.py | 2 +- griptape/tasks/base_text_input_task.py | 2 +- griptape/tasks/prompt_task.py | 4 +-- griptape/tools/query/tool.py | 4 ++- griptape/utils/chat.py | 8 ++--- .../logging/test_newline_logging_filter.py | 2 +- .../logging/test_truncate_logging_filter.py | 2 +- tests/unit/config/test_config.py | 16 ++++----- tests/unit/conftest.py | 2 +- .../drivers/prompt/test_base_prompt_driver.py | 6 ++-- .../test_local_structure_run_driver.py | 2 +- tests/unit/events/test_event_listener.py | 2 +- .../structure/test_conversation_memory.py | 10 ++++-- tests/unit/structures/test_agent.py | 2 +- tests/unit/tasks/test_structure_run_task.py | 4 +-- tests/unit/tasks/test_tool_task.py | 2 +- tests/unit/tasks/test_toolkit_task.py | 6 ++-- tests/unit/utils/test_chat.py | 2 +- tests/utils/structure_tester.py | 2 +- 53 files changed, 106 insertions(+), 96 deletions(-) diff --git a/docs/examples/src/multiple_agent_shared_memory_1.py b/docs/examples/src/multiple_agent_shared_memory_1.py index c6c61f966..b6089c190 100644 --- a/docs/examples/src/multiple_agent_shared_memory_1.py +++ b/docs/examples/src/multiple_agent_shared_memory_1.py @@ -34,9 +34,9 @@ vector_path=MONGODB_VECTOR_PATH, ) -config.drivers = AzureOpenAiDriverConfig( +config.driver_config = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, - vector_store=mongo_driver, + vector_store_driver=mongo_driver, embedding_driver=embedding_driver, ) diff --git a/docs/examples/src/talk_to_a_video_1.py b/docs/examples/src/talk_to_a_video_1.py index 377e177a6..2748902a2 100644 --- a/docs/examples/src/talk_to_a_video_1.py +++ b/docs/examples/src/talk_to_a_video_1.py @@ -7,7 +7,7 @@ from griptape.config.drivers import GoogleDriverConfig from griptape.structures import Agent -config.drivers = GoogleDriverConfig() +config.driver_config = GoogleDriverConfig() video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") while video_file.state.name == "PROCESSING": diff --git a/docs/griptape-framework/drivers/src/embedding_drivers_10.py b/docs/griptape-framework/drivers/src/embedding_drivers_10.py index 75e2963d1..2705dcfad 100644 --- a/docs/griptape-framework/drivers/src/embedding_drivers_10.py +++ b/docs/griptape-framework/drivers/src/embedding_drivers_10.py @@ -7,12 +7,12 @@ from griptape.structures import Agent from griptape.tools import PromptSummaryTool, WebScraperTool -config.drivers = DriverConfig( +config.driver_config = DriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), embedding_driver=VoyageAiEmbeddingDriver(), ) -config.drivers = DriverConfig( +config.driver_config = DriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), embedding_driver=VoyageAiEmbeddingDriver(), ) diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_4.py b/docs/griptape-framework/drivers/src/event_listener_drivers_4.py index d7f2b121e..ff59be794 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_4.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_4.py @@ -7,7 +7,7 @@ from griptape.rules import Rule from griptape.structures import Agent -config.drivers = DriverConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7)) +config.driver_config = 
DriverConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7)) event_bus.add_event_listeners( [ EventListener( diff --git a/docs/griptape-framework/structures/src/config_1.py b/docs/griptape-framework/structures/src/config_1.py index df75488dc..e038130c2 100644 --- a/docs/griptape-framework/structures/src/config_1.py +++ b/docs/griptape-framework/structures/src/config_1.py @@ -2,6 +2,6 @@ from griptape.config.drivers import OpenAiDriverConfig from griptape.structures import Agent -config.drivers = OpenAiDriverConfig() +config.driver_config = OpenAiDriverConfig() agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_2.py b/docs/griptape-framework/structures/src/config_2.py index 6fcdedbc8..a187e8c06 100644 --- a/docs/griptape-framework/structures/src/config_2.py +++ b/docs/griptape-framework/structures/src/config_2.py @@ -4,7 +4,7 @@ from griptape.config.drivers import AzureOpenAiDriverConfig from griptape.structures import Agent -config.drivers = AzureOpenAiDriverConfig( +config.driver_config = AzureOpenAiDriverConfig( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ) diff --git a/docs/griptape-framework/structures/src/config_3.py b/docs/griptape-framework/structures/src/config_3.py index e4e33e379..4d08912f9 100644 --- a/docs/griptape-framework/structures/src/config_3.py +++ b/docs/griptape-framework/structures/src/config_3.py @@ -6,7 +6,7 @@ from griptape.config.drivers import AmazonBedrockDriverConfig from griptape.structures import Agent -config.drivers = AmazonBedrockDriverConfig( +config.driver_config = AmazonBedrockDriverConfig( session=boto3.Session( region_name=os.environ["AWS_DEFAULT_REGION"], aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], diff --git a/docs/griptape-framework/structures/src/config_4.py b/docs/griptape-framework/structures/src/config_4.py index 7ab5eee70..e97422388 100644 --- a/docs/griptape-framework/structures/src/config_4.py +++ b/docs/griptape-framework/structures/src/config_4.py @@ -2,6 +2,6 @@ from griptape.config.drivers import GoogleDriverConfig from griptape.structures import Agent -config.drivers = GoogleDriverConfig() +config.driver_config = GoogleDriverConfig() agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_5.py b/docs/griptape-framework/structures/src/config_5.py index bee5050c2..519b770df 100644 --- a/docs/griptape-framework/structures/src/config_5.py +++ b/docs/griptape-framework/structures/src/config_5.py @@ -2,6 +2,6 @@ from griptape.config.drivers import AnthropicDriverConfig from griptape.structures import Agent -config.drivers = AnthropicDriverConfig() +config.driver_config = AnthropicDriverConfig() agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_6.py b/docs/griptape-framework/structures/src/config_6.py index 569000180..c53d8c1b0 100644 --- a/docs/griptape-framework/structures/src/config_6.py +++ b/docs/griptape-framework/structures/src/config_6.py @@ -4,6 +4,6 @@ from griptape.config.drivers import CohereDriverConfig from griptape.structures import Agent -config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) +config.driver_config = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_7.py b/docs/griptape-framework/structures/src/config_7.py index 535f6e4d8..3f63d428e 100644 --- a/docs/griptape-framework/structures/src/config_7.py +++ b/docs/griptape-framework/structures/src/config_7.py @@ 
-5,7 +5,7 @@ from griptape.drivers import AnthropicPromptDriver from griptape.structures import Agent -config.drivers = DriverConfig( +config.driver_config = DriverConfig( prompt_driver=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], diff --git a/docs/griptape-framework/structures/src/config_8.py b/docs/griptape-framework/structures/src/config_8.py index 6bc87998c..909a2d5d9 100644 --- a/docs/griptape-framework/structures/src/config_8.py +++ b/docs/griptape-framework/structures/src/config_8.py @@ -13,6 +13,6 @@ } custom_config = AmazonBedrockDriverConfig.from_dict(dict_config) -config.drivers = custom_config +config.driver_config = custom_config agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_logging.py b/docs/griptape-framework/structures/src/config_logging.py index 81645d5e2..4dceb6edd 100644 --- a/docs/griptape-framework/structures/src/config_logging.py +++ b/docs/griptape-framework/structures/src/config_logging.py @@ -5,9 +5,9 @@ from griptape.config.logging import TruncateLoggingFilter from griptape.structures import Agent -config.drivers = OpenAiDriverConfig() +config.driver_config = OpenAiDriverConfig() -logger = logging.getLogger(config.logging.logger_name) +logger = logging.getLogger(config.logging_config.logger_name) logger.setLevel(logging.ERROR) logger.addFilter(TruncateLoggingFilter(max_log_length=100)) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 3eb75a921..1ee4538d7 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -11,11 +11,11 @@ from griptape.structures import Agent from griptape.tools import FileManagerTool, QueryTool, WebScraperTool -config.drivers = OpenAiDriverConfig( +config.driver_config = OpenAiDriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ) -config.drivers = OpenAiDriverConfig( +config.driver_config = OpenAiDriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ) diff --git a/docs/griptape-tools/official-tools/src/rest_api_tool_1.py b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py index b181d4ec7..1e82b2dd9 100644 --- a/docs/griptape-tools/official-tools/src/rest_api_tool_1.py +++ b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py @@ -8,7 +8,7 @@ from griptape.tasks import ToolkitTask from griptape.tools import RestApiTool -config.drivers = DriverConfig( +config.driver_config = DriverConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1), ) diff --git a/griptape/config/base_config.py b/griptape/config/base_config.py index 7ed00e445..e9f0bc687 100644 --- a/griptape/config/base_config.py +++ b/griptape/config/base_config.py @@ -14,9 +14,9 @@ @define(kw_only=True) class BaseConfig(SerializableMixin, ABC): - _logging: Optional[LoggingConfig] = field(alias="logging") - _drivers: Optional[BaseDriverConfig] = field(alias="drivers") + _logging_config: Optional[LoggingConfig] = field(alias="logging") + _driver_config: Optional[BaseDriverConfig] = field(alias="drivers") def reset(self) -> None: - self._logging = None - self._drivers = None + self._logging_config = None + self._driver_config = None diff --git a/griptape/config/config.py b/griptape/config/config.py index 11c2f9585..64f2575dd 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -14,29 +14,29 @@ @define(kw_only=True) class _Config(BaseConfig): - _logging: 
Optional[LoggingConfig] = field(default=None, alias="logging") - _drivers: Optional[BaseDriverConfig] = field(default=None, alias="drivers") + _logging_config: Optional[LoggingConfig] = field(default=None, alias="logging") + _driver_config: Optional[BaseDriverConfig] = field(default=None, alias="drivers") @property - def drivers(self) -> BaseDriverConfig: + def driver_config(self) -> BaseDriverConfig: """Lazily instantiates the drivers configuration to avoid client errors like missing API key.""" - if self._drivers is None: - self._drivers = OpenAiDriverConfig() - return self._drivers + if self._driver_config is None: + self._driver_config = OpenAiDriverConfig() + return self._driver_config - @drivers.setter - def drivers(self, drivers: BaseDriverConfig) -> None: - self._drivers = drivers + @driver_config.setter + def driver_config(self, drivers: BaseDriverConfig) -> None: + self._driver_config = drivers @property - def logging(self) -> LoggingConfig: - if self._logging is None: - self._logging = LoggingConfig() - return self._logging - - @logging.setter - def logging(self, logging: LoggingConfig) -> None: - self._logging = logging + def logging_config(self) -> LoggingConfig: + if self._logging_config is None: + self._logging_config = LoggingConfig() + return self._logging_config + + @logging_config.setter + def logging_config(self, logging: LoggingConfig) -> None: + self._logging_config = logging config = _Config() diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index 6ab8abc3a..ee5739d81 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -8,7 +8,7 @@ @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: config.drivers.audio_transcription_driver), kw_only=True + default=Factory(lambda: config.driver_config.audio_transcription_driver), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index 4d058f910..3036d8bf5 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: config.drivers.text_to_speech_driver), kw_only=True + default=Factory(lambda: config.driver_config.text_to_speech_driver), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index cb152cec7..0a28b65b3 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -18,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git 
a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 583d5bc92..9f72f16be 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: config.drivers.image_generation_driver) + kw_only=True, default=Factory(lambda: config.driver_config.image_generation_driver) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index 393da833a..1b8fce277 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -14,7 +14,7 @@ @define class ImageQueryEngine: image_query_driver: BaseImageQueryDriver = field( - default=Factory(lambda: config.drivers.image_query_driver), kw_only=True + default=Factory(lambda: config.driver_config.image_query_driver), kw_only=True ) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 9253122c0..d1bcdd3b8 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -18,7 +18,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule, RuleMixin): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver)) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver)) answer_token_offset: int = field(default=400) metadata: Optional[str] = field(default=None) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 4fb2bcbc8..c04e9025a 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -18,7 +18,9 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field( + default=Factory(lambda: config.driver_config.vector_store_driver) + ) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 3808416f9..04f30ca82 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = 
field(default=Factory(lambda: config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 8082b3bf4..86431122a 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: config.drivers.conversation_memory_driver), kw_only=True + default=Factory(lambda: config.driver_config.conversation_memory_driver), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = config.drivers.prompt_driver + prompt_driver = config.driver_config.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index ed2807216..736891d90 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: config.drivers.prompt_driver)) + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: config.driver_config.prompt_driver)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index dd33b0023..5eb3ab734 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -14,7 +14,9 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store_driver)) + vector_store_driver: BaseVectorStoreDriver = field( + default=Factory(lambda: config.driver_config.vector_store_driver) + ) def can_store(self, artifact: BaseArtifact) -> bool: return isinstance(artifact, TextArtifact) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 549e7c665..2c4edfc7d 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -23,8 +23,8 @@ class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) - stream: bool = field(default=Factory(lambda: 
config.drivers.prompt_driver.stream), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver), kw_only=True) + stream: bool = field(default=Factory(lambda: config.driver_config.prompt_driver.stream), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 83ffc2081..da6f214eb 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from griptape.memory import TaskMemory -logger = logging.getLogger(config.logging.logger_name) +logger = logging.getLogger(config.logging_config.logger_name) @define diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index 00774e0a2..b24ec3f3c 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -9,7 +9,7 @@ from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.tasks import BaseTask -logger = logging.getLogger(config.logging.logger_name) +logger = logging.getLogger(config.logging_config.logger_name) @define diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py index 8a470bb85..fc94ed5d4 100644 --- a/griptape/tasks/base_audio_input_task.py +++ b/griptape/tasks/base_audio_input_task.py @@ -11,7 +11,7 @@ from griptape.mixins import RuleMixin from griptape.tasks import BaseTask -logger = logging.getLogger(config.logging.logger_name) +logger = logging.getLogger(config.logging_config.logger_name) @define diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index f94ff8de2..9c226256b 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -18,7 +18,7 @@ from griptape.artifacts import MediaArtifact -logger = logging.getLogger(config.logging.logger_name) +logger = logging.getLogger(config.logging_config.logger_name) @define diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index c688a1129..52b68e4e0 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -12,7 +12,7 @@ from griptape.tasks import BaseTask from griptape.utils import J2 -logger = logging.getLogger(config.logging.logger_name) +logger = logging.getLogger(config.logging_config.logger_name) @define diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index f5a772e48..d80767793 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -18,7 +18,7 @@ from griptape.memory.meta import BaseMetaEntry from griptape.structures import Structure -logger = logging.getLogger(config.logging.logger_name) +logger = logging.getLogger(config.logging_config.logger_name) @define diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 1c9dfc023..0a53b9fcd 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -12,7 +12,7 @@ from griptape.tasks import BaseTask from griptape.utils import J2 -logger = logging.getLogger(config.logging.logger_name) +logger = 
logging.getLogger(config.logging_config.logger_name) @define diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 0f3acc8c6..719ae77bc 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -15,12 +15,12 @@ if TYPE_CHECKING: from griptape.drivers import BasePromptDriver -logger = logging.getLogger(config.logging.logger_name) +logger = logging.getLogger(config.logging_config.logger_name) @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index eed4c9522..70cc9f747 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -25,7 +25,9 @@ class QueryTool(BaseTool, RuleMixin): lambda self: RagEngine( response_stage=ResponseRagStage( response_modules=[ - PromptResponseRagModule(prompt_driver=config.drivers.prompt_driver, rulesets=self.rulesets) + PromptResponseRagModule( + prompt_driver=config.driver_config.prompt_driver, rulesets=self.rulesets + ) ], ), ), diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index f34a97504..f30e9f1cd 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -41,8 +41,8 @@ def start(self) -> None: from griptape.config import config # Hide Griptape's logging output except for errors - old_logger_level = logging.getLogger(config.logging.logger_name).getEffectiveLevel() - logging.getLogger(config.logging.logger_name).setLevel(self.logger_level) + old_logger_level = logging.getLogger(config.logging_config.logger_name).getEffectiveLevel() + logging.getLogger(config.logging_config.logger_name).setLevel(self.logger_level) if self.intro_text: self.output_fn(self.intro_text) @@ -53,7 +53,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if config.drivers.prompt_driver.stream: + if config.driver_config.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) @@ -65,4 +65,4 @@ def start(self) -> None: self.output_fn(f"{self.response_prefix}{self.structure.run(question).output_task.output.to_text()}") # Restore the original logger level - logging.getLogger(config.logging.logger_name).setLevel(old_logger_level) + logging.getLogger(config.logging_config.logger_name).setLevel(old_logger_level) diff --git a/tests/unit/config/logging/test_newline_logging_filter.py b/tests/unit/config/logging/test_newline_logging_filter.py index d5b05e323..89166dd40 100644 --- a/tests/unit/config/logging/test_newline_logging_filter.py +++ b/tests/unit/config/logging/test_newline_logging_filter.py @@ -10,7 +10,7 @@ class TestNewlineLoggingFilter: def test_filter(self): # use the filter in an Agent - logger = logging.getLogger(config.logging.logger_name) + logger = logging.getLogger(config.logging_config.logger_name) logger.addFilter(NewlineLoggingFilter(replace_str="$$$")) agent = Agent() # use a context manager to capture the stdout diff --git a/tests/unit/config/logging/test_truncate_logging_filter.py b/tests/unit/config/logging/test_truncate_logging_filter.py index fc0aa1c47..a9387b52b 100644 --- 
a/tests/unit/config/logging/test_truncate_logging_filter.py +++ b/tests/unit/config/logging/test_truncate_logging_filter.py @@ -10,7 +10,7 @@ class TestTruncateLoggingFilter: def test_filter(self): # use the filter in an Agent - logger = logging.getLogger(config.logging.logger_name) + logger = logging.getLogger(config.logging_config.logger_name) logger.addFilter(TruncateLoggingFilter(max_log_length=0)) agent = Agent() # use a context manager to capture the stdout diff --git a/tests/unit/config/test_config.py b/tests/unit/config/test_config.py index 04d5586d2..4ed75d325 100644 --- a/tests/unit/config/test_config.py +++ b/tests/unit/config/test_config.py @@ -9,18 +9,18 @@ def test_init(self): from griptape.config import config from griptape.config.logging import LoggingConfig - assert isinstance(config.drivers, OpenAiDriverConfig) - assert isinstance(config.logging, LoggingConfig) + assert isinstance(config.driver_config, OpenAiDriverConfig) + assert isinstance(config.logging_config, LoggingConfig) @pytest.mark.skip_mock_config() def test_lazy_init(self): from griptape.config import config - assert config._drivers is None - assert config._logging is None + assert config._driver_config is None + assert config._logging_config is None - assert config.drivers is not None - assert config.logging is not None + assert config.driver_config is not None + assert config.logging_config is not None - assert config._drivers is not None - assert config._logging is not None + assert config._driver_config is not None + assert config._logging_config is not None diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index db881bc20..e2eaabe1f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -26,7 +26,7 @@ def mock_config(request): return - config.drivers = MockDriverConfig() + config.driver_config = MockDriverConfig() yield config diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index c575a01d3..3e0b0ffc8 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -11,7 +11,7 @@ class TestBasePromptDriver: def test_run_via_pipeline_retries_success(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=2) + mock_config.driver_config.prompt_driver = MockPromptDriver(max_attempts=2) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -19,7 +19,7 @@ def test_run_via_pipeline_retries_success(self, mock_config): assert isinstance(pipeline.run().output_task.output, TextArtifact) def test_run_via_pipeline_retries_failure(self, mock_config): - mock_config.drivers.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) + mock_config.driver_config.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -46,7 +46,7 @@ def test_run_with_stream(self): assert result.value == "mock output" def test_run_with_tools(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) + mock_config.driver_config.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index b2e9c069b..2090be39c 100644 --- 
a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -20,7 +20,7 @@ def test_run(self): def test_run_with_env(self, mock_config): pipeline = Pipeline() - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + mock_config.driver_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index e66b8816b..0078ebc34 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -26,7 +26,7 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(stream=True) + mock_config.driver_config.prompt_driver = MockPromptDriver(stream=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) pipeline = Pipeline() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 06e54e6c4..b7137524e 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -97,7 +97,9 @@ def test_add_to_prompt_stack_autopruing_disabled(self): def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + mock_config.driver_config.prompt_driver = MockPromptDriver( + tokenizer=MockTokenizer(model="foo", max_input_tokens=0) + ) agent = Agent() memory = ConversationMemory( autoprune=True, @@ -119,7 +121,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): assert len(prompt_stack.messages) == 3 # No memory is pruned. - mock_config.drivers.prompt_driver = MockPromptDriver( + mock_config.driver_config.prompt_driver = MockPromptDriver( tokenizer=MockTokenizer(model="foo", max_input_tokens=1000) ) agent = Agent() @@ -145,7 +147,9 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. 
- mock_config.drivers.prompt_driver = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + mock_config.driver_config.prompt_driver = MockPromptDriver( + tokenizer=MockTokenizer(model="foo", max_input_tokens=160) + ) agent = Agent() memory = ConversationMemory( autoprune=True, diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 36a73db74..d90f2f8ba 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -222,7 +222,7 @@ def test_task_memory_defaults(self, mock_config): storage = list(agent.task_memory.artifact_storages.values())[0] assert isinstance(storage, TextArtifactStorage) - assert storage.vector_store_driver.embedding_driver == mock_config.drivers.embedding_driver + assert storage.vector_store_driver.embedding_driver == mock_config.driver_config.embedding_driver def finished_tasks(self): task = PromptTask("test prompt") diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 2c0dc1b28..fe434a281 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -6,9 +6,9 @@ class TestStructureRunTask: def test_run(self, mock_config): - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="agent mock output") + mock_config.driver_config.prompt_driver = MockPromptDriver(mock_output="agent mock output") agent = Agent() - mock_config.drivers.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") + mock_config.driver_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 70ab05e12..9ba1df731 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -168,7 +168,7 @@ class TestToolTask: def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - mock_config.drivers.prompt_driver = MockPromptDriver( + mock_config.driver_config.prompt_driver = MockPromptDriver( mock_output=f"```python foo bar\n{json.dumps(output_dict)}" ) diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 15f5a59b1..6837fca78 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -171,7 +171,7 @@ def test_init(self): def test_run(self, mock_config): output = """Answer: done""" - mock_config.drivers.prompt_driver.mock_output = output + mock_config.driver_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent() @@ -186,7 +186,7 @@ def test_run(self, mock_config): def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' - mock_config.drivers.prompt_driver.mock_output = output + mock_config.driver_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() @@ -200,7 +200,7 @@ def test_run_max_subtasks(self, mock_config): def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" - mock_config.drivers.prompt_driver.mock_output = output + mock_config.driver_config.prompt_driver.mock_output = output task = ToolkitTask("test", 
tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() diff --git a/tests/unit/utils/test_chat.py b/tests/unit/utils/test_chat.py index 5c73a6845..ff728718a 100644 --- a/tests/unit/utils/test_chat.py +++ b/tests/unit/utils/test_chat.py @@ -37,7 +37,7 @@ def test_chat_logger_level(self, mock_input): chat = Chat(agent) - logger = logging.getLogger(config.logging.logger_name) + logger = logging.getLogger(config.logging_config.logger_name) logger.setLevel(logging.DEBUG) assert logger.getEffectiveLevel() == logging.DEBUG diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 317fc0e84..7de57d85c 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -230,7 +230,7 @@ def prompt_driver_id_fn(cls, prompt_driver) -> str: def verify_structure_output(self, structure) -> dict: from griptape.config import config - config.drivers.prompt_driver = AzureOpenAiChatPromptDriver( + config.driver_config.prompt_driver = AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-4o", azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], From 417b442b461497a9a2d03757e7cb7c306eee4e6e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 19 Aug 2024 11:21:08 -0700 Subject: [PATCH 3/3] Use lazy decorator --- griptape/config/config.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/griptape/config/config.py b/griptape/config/config.py index 64f2575dd..5f5dbb5db 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -4,6 +4,8 @@ from attrs import define, field +from griptape.utils.decorators import lazy_property + from .base_config import BaseConfig from .drivers.openai_driver_config import OpenAiDriverConfig from .logging.logging_config import LoggingConfig @@ -17,26 +19,13 @@ class _Config(BaseConfig): _logging_config: Optional[LoggingConfig] = field(default=None, alias="logging") _driver_config: Optional[BaseDriverConfig] = field(default=None, alias="drivers") - @property + @lazy_property() def driver_config(self) -> BaseDriverConfig: - """Lazily instantiates the drivers configuration to avoid client errors like missing API key.""" - if self._driver_config is None: - self._driver_config = OpenAiDriverConfig() - return self._driver_config - - @driver_config.setter - def driver_config(self, drivers: BaseDriverConfig) -> None: - self._driver_config = drivers + return OpenAiDriverConfig() - @property + @lazy_property() def logging_config(self) -> LoggingConfig: - if self._logging_config is None: - self._logging_config = LoggingConfig() - return self._logging_config - - @logging_config.setter - def logging_config(self, logging: LoggingConfig) -> None: - self._logging_config = logging + return LoggingConfig() config = _Config()
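
For context, a minimal usage sketch with the renamed accessors (field and import names are taken from the diffs above; the specific driver and model choices are illustrative only, not part of this patch):

import logging

from griptape.config import config
from griptape.config.drivers import OpenAiDriverConfig
from griptape.drivers import OpenAiChatPromptDriver

# config.drivers -> config.driver_config; nested fields gain a _driver suffix
# (prompt -> prompt_driver, embedding -> embedding_driver, vector_store -> vector_store_driver).
config.driver_config = OpenAiDriverConfig(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
)

# config.logging -> config.logging_config.
logger = logging.getLogger(config.logging_config.logger_name)
logger.setLevel(logging.ERROR)

Any caller still reading config.drivers or config.logging would need the equivalent update after this series.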