diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c5ad46f7..1b7b08f36 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,14 +13,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Method `try_find_task` to `Structure`. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. +- Global config, `griptape.config.config`, for setting global configuration defaults. - Unique name generation for all `RagEngine` modules. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. +- **BREAKING**: Removed `Pipeline.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. +- **BREAKING**: Removed `Pipeline.stream` and `Workflow.stream`. `Agent.stream` has not been removed. +- **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.config.config.drivers.embedding` instead. +- **BREAKING**: Removed `Structure.custom_logger` and `Structure.logger_level`, set these via `griptape.config.config.logger` instead. +- **BREAKING**: Removed `BaseStructureConfig.merge_config`. +- **BREAKING**: Renamed `StructureConfig` to `DriverConfig`, and renamed fields accordingly. - **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly. - **BREAKING**: Removed before and after response modules from `ResponseRagStage`. - **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. +- Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.1] - 2024-08-02 diff --git a/docs/examples/src/multiple_agent_shared_memory_1.py b/docs/examples/src/multiple_agent_shared_memory_1.py index dd2bb3076..118684d37 100644 --- a/docs/examples/src/multiple_agent_shared_memory_1.py +++ b/docs/examples/src/multiple_agent_shared_memory_1.py @@ -1,6 +1,6 @@ import os -from griptape.config import AzureOpenAiStructureConfig +from griptape.config import AzureOpenAiDriverConfig, config from griptape.drivers import AzureMongoDbVectorStoreDriver, AzureOpenAiEmbeddingDriver from griptape.structures import Agent from griptape.tools import TaskMemoryClient, WebScraper @@ -33,17 +33,16 @@ vector_path=MONGODB_VECTOR_PATH, ) -config = AzureOpenAiStructureConfig( +config.drivers = AzureOpenAiDriverConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, - vector_store_driver=mongo_driver, - embedding_driver=embedding_driver, + vector_store=mongo_driver, + embedding=embedding_driver, ) loader = Agent( tools=[ WebScraper(off_prompt=True), ], - config=config, ) asker = Agent( tools=[ @@ -51,7 +50,6 @@ ], meta_memory=loader.meta_memory, task_memory=loader.task_memory, - config=config, ) if __name__ == "__main__": diff --git a/docs/examples/src/talk_to_a_video_1.py b/docs/examples/src/talk_to_a_video_1.py index d237b5191..3538c9071 100644 --- a/docs/examples/src/talk_to_a_video_1.py +++ b/docs/examples/src/talk_to_a_video_1.py @@ -3,9 +3,11 @@ import google.generativeai as genai from griptape.artifacts import GenericArtifact, TextArtifact -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig, config from griptape.structures import Agent +config.drivers = GoogleDriverConfig() + video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") while video_file.state.name == "PROCESSING": time.sleep(2) @@ -15,11 +17,10 @@ raise ValueError(video_file.state.name) agent = Agent( - config=GoogleStructureConfig(), input=[ GenericArtifact(video_file), TextArtifact("Answer this question regarding the video: {{ args[0] }}"), - ], + ] ) agent.run("Are there any scenes that show a character with earings?") diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 4209c2821..68f40f09e 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -128,7 +128,7 @@ The [CohereEmbeddingDriver](../../reference/griptape/drivers/embedding/cohere_em ``` ### Override Default Structure Embedding Driver -Here is how you can override the Embedding Driver that is used by default in Structures. +Here is how you can override the Embedding Driver that is used by default in Structures. ```python --8<-- "docs/griptape-framework/drivers/src/embedding_drivers_10.py" diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index ab5be55f3..54230b999 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -166,7 +166,6 @@ The [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prom Amazon Sagemaker Jumpstart provides a wide range of models with varying capabilities. This Driver has been primarily _chat-optimized_ models that have a [Huggingface Chat Template](https://huggingface.co/docs/transformers/en/chat_templating) available. If your model does not fit this use-case, we suggest sub-classing [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.md) and overriding the `_to_model_input` and `_to_model_params` methods. - ```python --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_14.py" diff --git a/docs/griptape-framework/drivers/src/embedding_drivers_10.py b/docs/griptape-framework/drivers/src/embedding_drivers_10.py index 40d08349f..3ef816b29 100644 --- a/docs/griptape-framework/drivers/src/embedding_drivers_10.py +++ b/docs/griptape-framework/drivers/src/embedding_drivers_10.py @@ -1,4 +1,4 @@ -from griptape.config import StructureConfig +from griptape.config import DriverConfig, config from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, @@ -6,12 +6,13 @@ from griptape.structures import Agent from griptape.tools import TaskMemoryClient, WebScraper +config.drivers = DriverConfig( + prompt=OpenAiChatPromptDriver(model="gpt-4o"), + embedding=VoyageAiEmbeddingDriver(), +) + agent = Agent( tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), - embedding_driver=VoyageAiEmbeddingDriver(), - ), ) agent.run("based on https://www.griptape.ai/, tell me what Griptape is") diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_4.py b/docs/griptape-framework/drivers/src/event_listener_drivers_4.py index dfac295b7..a70e05d79 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_4.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_4.py @@ -1,11 +1,12 @@ import os -from griptape.config import StructureConfig +from griptape.config import DriverConfig, config from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver from griptape.events import EventListener, FinishStructureRunEvent, event_bus from griptape.rules import Rule from griptape.structures import Agent +config.drivers = DriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7)) event_bus.add_event_listeners( [ EventListener( @@ -20,7 +21,6 @@ agent = Agent( rules=[Rule(value="You will be provided with a text, and your task is to extract the airport codes from it.")], - config=StructureConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7)), ) agent.run("I want to fly from Orlando to Boston") diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_1.py b/docs/griptape-framework/drivers/src/prompt_drivers_1.py index 978435f2d..ab5273228 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_1.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_1.py @@ -1,12 +1,9 @@ -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), - ), + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", rules=[Rule(value="Output only the sentiment.")], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_10.py b/docs/griptape-framework/drivers/src/prompt_drivers_10.py index 02f083570..04e2acb35 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_10.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_10.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import OllamaPromptDriver from griptape.structures import Agent from griptape.tools import Calculator agent = Agent( - config=StructureConfig( - prompt_driver=OllamaPromptDriver( - model="llama3.1", - ), + prompt_driver=OllamaPromptDriver( + model="llama3.1", ), tools=[Calculator()], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_11.py b/docs/griptape-framework/drivers/src/prompt_drivers_11.py index 1c81c4785..9e838473c 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_11.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_11.py @@ -1,16 +1,13 @@ import os -from griptape.config import StructureConfig from griptape.drivers import HuggingFaceHubPromptDriver from griptape.rules import Rule, Ruleset from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="HuggingFaceH4/zephyr-7b-beta", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - ) + prompt_driver=HuggingFaceHubPromptDriver( + model="HuggingFaceH4/zephyr-7b-beta", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ), rulesets=[ Ruleset( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_12.py b/docs/griptape-framework/drivers/src/prompt_drivers_12.py index d6f59f96e..d555c32c9 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_12.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_12.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import HuggingFaceHubPromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFaceHubPromptDriver( - model="http://127.0.0.1:8080", - api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - ), + prompt_driver=HuggingFaceHubPromptDriver( + model="http://127.0.0.1:8080", + api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], ), ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_13.py b/docs/griptape-framework/drivers/src/prompt_drivers_13.py index e4fe5208c..d3ddd9093 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_13.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_13.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import HuggingFacePipelinePromptDriver from griptape.rules import Rule, Ruleset from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=HuggingFacePipelinePromptDriver( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - ) + prompt_driver=HuggingFacePipelinePromptDriver( + model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", ), rulesets=[ Ruleset( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_14.py b/docs/griptape-framework/drivers/src/prompt_drivers_14.py index 85bd5216e..228a5f9b2 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_14.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_14.py @@ -1,17 +1,14 @@ import os -from griptape.config import StructureConfig from griptape.drivers import ( AmazonSageMakerJumpstartPromptDriver, ) from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AmazonSageMakerJumpstartPromptDriver( - endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], - model="meta-llama/Meta-Llama-3-8B-Instruct", - ) + prompt_driver=AmazonSageMakerJumpstartPromptDriver( + endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], + model="meta-llama/Meta-Llama-3-8B-Instruct", ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_3.py b/docs/griptape-framework/drivers/src/prompt_drivers_3.py index b92596aca..8e85ce887 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_3.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_3.py @@ -1,19 +1,16 @@ import os -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver( - api_key=os.environ["OPENAI_API_KEY"], - temperature=0.1, - model="gpt-4o", - response_format="json_object", - seed=42, - ) + prompt_driver=OpenAiChatPromptDriver( + api_key=os.environ["OPENAI_API_KEY"], + temperature=0.1, + model="gpt-4o", + response_format="json_object", + seed=42, ), input="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}", rules=[Rule(value='Write your output in json with a single key called "css_code".')], diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_4.py b/docs/griptape-framework/drivers/src/prompt_drivers_4.py index b024638b7..bcafb40de 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_4.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_4.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver( - base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True - ) + prompt_driver=OpenAiChatPromptDriver( + base_url="http://127.0.0.1:1234/v1", model="lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF", stream=True ), rules=[Rule(value="You are a helpful coding assistant.")], ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_5.py b/docs/griptape-framework/drivers/src/prompt_drivers_5.py index ffe9a4e0a..76301d8d9 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_5.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_5.py @@ -1,18 +1,15 @@ import os -from griptape.config import StructureConfig from griptape.drivers import AzureOpenAiChatPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AzureOpenAiChatPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="gpt-3.5-turbo", - azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - ) + prompt_driver=AzureOpenAiChatPromptDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + model="gpt-3.5-turbo", + azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_DEPLOYMENT_ID"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], ), rules=[ Rule( diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_6.py b/docs/griptape-framework/drivers/src/prompt_drivers_6.py index 2bd1b00fb..5e4d226a6 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_6.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_6.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import CoherePromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=CoherePromptDriver( - model="command-r", - api_key=os.environ["COHERE_API_KEY"], - ) + prompt_driver=CoherePromptDriver( + model="command-r", + api_key=os.environ["COHERE_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_7.py b/docs/griptape-framework/drivers/src/prompt_drivers_7.py index dd1c15370..23f3d0c35 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_7.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_7.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import AnthropicPromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AnthropicPromptDriver( - model="claude-3-opus-20240229", - api_key=os.environ["ANTHROPIC_API_KEY"], - ) + prompt_driver=AnthropicPromptDriver( + model="claude-3-opus-20240229", + api_key=os.environ["ANTHROPIC_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_8.py b/docs/griptape-framework/drivers/src/prompt_drivers_8.py index 1bbf2848c..b6a1c109e 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_8.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_8.py @@ -1,15 +1,12 @@ import os -from griptape.config import StructureConfig from griptape.drivers import GooglePromptDriver from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=GooglePromptDriver( - model="gemini-pro", - api_key=os.environ["GOOGLE_API_KEY"], - ) + prompt_driver=GooglePromptDriver( + model="gemini-pro", + api_key=os.environ["GOOGLE_API_KEY"], ) ) diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_9.py b/docs/griptape-framework/drivers/src/prompt_drivers_9.py index cdd0db82d..992dbecd2 100644 --- a/docs/griptape-framework/drivers/src/prompt_drivers_9.py +++ b/docs/griptape-framework/drivers/src/prompt_drivers_9.py @@ -1,13 +1,10 @@ -from griptape.config import StructureConfig from griptape.drivers import AmazonBedrockPromptDriver from griptape.rules import Rule from griptape.structures import Agent agent = Agent( - config=StructureConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - ) + prompt_driver=AmazonBedrockPromptDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", ), rules=[ Rule( diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index 31790b494..bae8b8224 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -1,6 +1,5 @@ from typing import cast -from griptape.config import OpenAiStructureConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.events import CompletionChunkEvent, EventListener, event_bus from griptape.structures import Pipeline @@ -16,13 +15,11 @@ ] ) -pipeline = Pipeline( - config=OpenAiStructureConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True)), -) - +pipeline = Pipeline() pipeline.add_tasks( ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True), tools=[WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=False)], ) ) diff --git a/docs/griptape-framework/misc/src/events_4.py b/docs/griptape-framework/misc/src/events_4.py index 3816b110e..f5523cb11 100644 --- a/docs/griptape-framework/misc/src/events_4.py +++ b/docs/griptape-framework/misc/src/events_4.py @@ -4,7 +4,6 @@ from griptape.utils import Stream pipeline = Pipeline() -pipeline.config.prompt_driver.stream = True pipeline.add_tasks( ToolkitTask( "Based on https://griptape.ai, tell me what griptape is.", diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index e10b6cc73..89399f60c 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -5,16 +5,15 @@ search: ## Overview -The [StructureConfig](../../reference/griptape/config/structure_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. +The [DriverConfig](../../reference/griptape/config/driver_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. ### Premade Configs -Griptape provides predefined [StructureConfig](../../reference/griptape/config/structure_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. +Griptape provides predefined [DriverConfig](../../reference/griptape/config/driver_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. #### OpenAI -The [OpenAI Structure Config](../../reference/griptape/config/openai_structure_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. - +The [OpenAI Driver config](../../reference/griptape/config/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. ```python --8<-- "docs/griptape-framework/structures/src/config_1.py" @@ -22,22 +21,21 @@ The [OpenAI Structure Config](../../reference/griptape/config/openai_structure_c #### Azure OpenAI -The [Azure OpenAI Structure Config](../../reference/griptape/config/azure_openai_structure_config.md) provides default Drivers for Azure's OpenAI APIs. - +The [Azure OpenAI Driver config](../../reference/griptape/config/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. ```python --8<-- "docs/griptape-framework/structures/src/config_2.py" ``` #### Amazon Bedrock -The [Amazon Bedrock Structure Config](../../reference/griptape/config/amazon_bedrock_structure_config.md) provides default Drivers for Amazon Bedrock's APIs. +The [Amazon Bedrock Driver config](../../reference/griptape/config/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. ```python --8<-- "docs/griptape-framework/structures/src/config_3.py" ``` #### Google -The [Google Structure Config](../../reference/griptape/config/google_structure_config.md) provides default Drivers for Google's Gemini APIs. +The [Google Driver config](../../reference/griptape/config/google_driver_config.md) provides default Drivers for Google's Gemini APIs. ```python --8<-- "docs/griptape-framework/structures/src/config_4.py" @@ -45,22 +43,20 @@ The [Google Structure Config](../../reference/griptape/config/google_structure_c #### Anthropic -The [Anthropic Structure Config](../../reference/griptape/config/anthropic_structure_config.md) provides default Drivers for Anthropic's APIs. +The [Anthropic Driver config](../../reference/griptape/config/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. !!! info Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. - The `AnthropicStructureConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). + The `AnthropicDriverConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). To override the default embedding driver, see: [Override Default Structure Embedding Driver](../drivers/embedding-drivers.md#override-default-structure-embedding-driver). - ```python --8<-- "docs/griptape-framework/structures/src/config_5.py" ``` #### Cohere -The [Cohere Structure Config](../../reference/griptape/config/cohere_structure_config.md) provides default Drivers for Cohere's APIs. - +The [Cohere Driver config](../../reference/griptape/config/cohere_driver_config.md) provides default Drivers for Cohere's APIs. ```python --8<-- "docs/griptape-framework/structures/src/config_6.py" @@ -68,8 +64,8 @@ The [Cohere Structure Config](../../reference/griptape/config/cohere_structure_c ### Custom Configs -You can create your own [StructureConfig](../../reference/griptape/config/structure_config.md) by overriding relevant Drivers. -The [StructureConfig](../../reference/griptape/config/structure_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. +You can create your own [DriverConfig](../../reference/griptape/config/driver_config.md) by overriding relevant Drivers. +The [DriverConfig](../../reference/griptape/config/driver_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations. ```python @@ -78,9 +74,6 @@ This approach ensures that you are informed through clear error messages if you ### Loading/Saving Configs -Configuration classes in Griptape offer utility methods for loading, saving, and merging configurations, streamlining the management of complex setups. - ```python --8<-- "docs/griptape-framework/structures/src/config_8.py" ``` - diff --git a/docs/griptape-framework/structures/src/config_1.py b/docs/griptape-framework/structures/src/config_1.py index 613852193..0c7a5ed9e 100644 --- a/docs/griptape-framework/structures/src/config_1.py +++ b/docs/griptape-framework/structures/src/config_1.py @@ -1,6 +1,6 @@ -from griptape.config import OpenAiStructureConfig +from griptape.config import OpenAiDriverConfig, config from griptape.structures import Agent -agent = Agent(config=OpenAiStructureConfig()) +config.drivers = OpenAiDriverConfig() -agent = Agent() # This is equivalent to the above +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_2.py b/docs/griptape-framework/structures/src/config_2.py index bf6b6223b..a5f8efbbe 100644 --- a/docs/griptape-framework/structures/src/config_2.py +++ b/docs/griptape-framework/structures/src/config_2.py @@ -1,16 +1,10 @@ import os -from griptape.config import AzureOpenAiStructureConfig +from griptape.config import AzureOpenAiDriverConfig, config from griptape.structures import Agent -agent = Agent( - config=AzureOpenAiStructureConfig( - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] - ).merge_config( - { - "image_query_driver": { - "azure_deployment": "gpt-4o", - }, - } - ), +config.drivers = AzureOpenAiDriverConfig( + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ) + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_3.py b/docs/griptape-framework/structures/src/config_3.py index acb55c007..6b3f51a76 100644 --- a/docs/griptape-framework/structures/src/config_3.py +++ b/docs/griptape-framework/structures/src/config_3.py @@ -2,15 +2,15 @@ import boto3 -from griptape.config import AmazonBedrockStructureConfig +from griptape.config import AmazonBedrockDriverConfig, config from griptape.structures import Agent -agent = Agent( - config=AmazonBedrockStructureConfig( - session=boto3.Session( - region_name=os.environ["AWS_DEFAULT_REGION"], - aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], - aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], - ) +config.drivers = AmazonBedrockDriverConfig( + session=boto3.Session( + region_name=os.environ["AWS_DEFAULT_REGION"], + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], ) ) + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_4.py b/docs/griptape-framework/structures/src/config_4.py index 88d388161..5362b8c6b 100644 --- a/docs/griptape-framework/structures/src/config_4.py +++ b/docs/griptape-framework/structures/src/config_4.py @@ -1,4 +1,6 @@ -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig, config from griptape.structures import Agent -agent = Agent(config=GoogleStructureConfig()) +config.drivers = GoogleDriverConfig() + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_5.py b/docs/griptape-framework/structures/src/config_5.py index f0da565a7..4f787a922 100644 --- a/docs/griptape-framework/structures/src/config_5.py +++ b/docs/griptape-framework/structures/src/config_5.py @@ -1,4 +1,6 @@ -from griptape.config import AnthropicStructureConfig +from griptape.config import AnthropicDriverConfig, config from griptape.structures import Agent -agent = Agent(config=AnthropicStructureConfig()) +config.drivers = AnthropicDriverConfig() + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_6.py b/docs/griptape-framework/structures/src/config_6.py index 4e733e9d4..c26502401 100644 --- a/docs/griptape-framework/structures/src/config_6.py +++ b/docs/griptape-framework/structures/src/config_6.py @@ -1,6 +1,8 @@ import os -from griptape.config import CohereStructureConfig +from griptape.config import CohereDriverConfig, config from griptape.structures import Agent -agent = Agent(config=CohereStructureConfig(api_key=os.environ["COHERE_API_KEY"])) +config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_7.py b/docs/griptape-framework/structures/src/config_7.py index 79a81d878..6b285f0e5 100644 --- a/docs/griptape-framework/structures/src/config_7.py +++ b/docs/griptape-framework/structures/src/config_7.py @@ -1,14 +1,15 @@ import os -from griptape.config import StructureConfig +from griptape.config import DriverConfig, config from griptape.drivers import AnthropicPromptDriver from griptape.structures import Agent -agent = Agent( - config=StructureConfig( - prompt_driver=AnthropicPromptDriver( - model="claude-3-sonnet-20240229", - api_key=os.environ["ANTHROPIC_API_KEY"], - ) - ), +config.drivers = DriverConfig( + prompt=AnthropicPromptDriver( + model="claude-3-sonnet-20240229", + api_key=os.environ["ANTHROPIC_API_KEY"], + ) ) + + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_8.py b/docs/griptape-framework/structures/src/config_8.py index 9bb1d01c9..4f23e3eaa 100644 --- a/docs/griptape-framework/structures/src/config_8.py +++ b/docs/griptape-framework/structures/src/config_8.py @@ -1,28 +1,17 @@ -from griptape.config import AmazonBedrockStructureConfig -from griptape.drivers import AmazonBedrockCohereEmbeddingDriver +from griptape.config import AmazonBedrockDriverConfig, config from griptape.structures import Agent -custom_config = AmazonBedrockStructureConfig() -custom_config.embedding_driver = AmazonBedrockCohereEmbeddingDriver() -custom_config.merge_config( - { - "embedding_driver": { - "base_url": None, - "model": "text-embedding-3-small", - "organization": None, - "type": "OpenAiEmbeddingDriver", - }, - } -) -serialized_config = custom_config.to_json() -deserialized_config = AmazonBedrockStructureConfig.from_json(serialized_config) +custom_config = AmazonBedrockDriverConfig() +dict_config = custom_config.to_dict() +# Use OpenAi for embeddings +dict_config["embedding"] = { + "base_url": None, + "model": "text-embedding-3-small", + "organization": None, + "type": "OpenAiEmbeddingDriver", +} +custom_config = AmazonBedrockDriverConfig.from_dict(dict_config) -agent = Agent( - config=deserialized_config.merge_config( - { - "prompt_driver": { - "model": "anthropic.claude-3-sonnet-20240229-v1:0", - }, - } - ), -) +config.drivers = custom_config + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 7bbc5614a..fb5c3eabb 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -1,6 +1,7 @@ from griptape.artifacts import TextArtifact from griptape.config import ( - OpenAiStructureConfig, + OpenAiDriverConfig, + config, ) from griptape.drivers import ( LocalVectorStoreDriver, @@ -15,12 +16,13 @@ from griptape.structures import Agent from griptape.tools import FileManager, TaskMemoryClient, WebScraper +config.drivers = OpenAiDriverConfig( + prompt=OpenAiChatPromptDriver(model="gpt-4"), +) + vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) agent = Agent( - config=OpenAiStructureConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), - ), task_memory=TaskMemory( artifact_storages={ TextArtifact: TextArtifactStorage( diff --git a/docs/griptape-tools/official-tools/rest-api-client.md b/docs/griptape-tools/official-tools/rest-api-client.md index 447f656d5..8ec4804c2 100644 --- a/docs/griptape-tools/official-tools/rest-api-client.md +++ b/docs/griptape-tools/official-tools/rest-api-client.md @@ -6,7 +6,7 @@ The [RestApiClient](../../reference/griptape/tools/rest_api_client/tool.md) tool ### Example The following example is built using [https://jsonplaceholder.typicode.com/guide/](https://jsonplaceholder.typicode.com/guide/). - + ```python --8<-- "docs/griptape-tools/official-tools/src/rest_api_client_1.py" ``` diff --git a/docs/griptape-tools/official-tools/src/rest_api_client_1.py b/docs/griptape-tools/official-tools/src/rest_api_client_1.py index 01373de00..026874283 100644 --- a/docs/griptape-tools/official-tools/src/rest_api_client_1.py +++ b/docs/griptape-tools/official-tools/src/rest_api_client_1.py @@ -1,12 +1,16 @@ from json import dumps -from griptape.config import StructureConfig +from griptape.config import DriverConfig, config from griptape.drivers import OpenAiChatPromptDriver from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiClient +config.drivers = DriverConfig( + prompt=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1), +) + posts_client = RestApiClient( base_url="https://jsonplaceholder.typicode.com", path="posts", @@ -108,9 +112,6 @@ pipeline = Pipeline( conversation_memory=ConversationMemory(), - config=StructureConfig( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1), - ), ) pipeline.add_tasks( diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 541eb0db0..b242d80a7 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -1,24 +1,26 @@ from .base_config import BaseConfig -from .base_structure_config import BaseStructureConfig +from .base_driver_config import BaseDriverConfig -from .structure_config import StructureConfig -from .openai_structure_config import OpenAiStructureConfig -from .azure_openai_structure_config import AzureOpenAiStructureConfig -from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig -from .anthropic_structure_config import AnthropicStructureConfig -from .google_structure_config import GoogleStructureConfig -from .cohere_structure_config import CohereStructureConfig +from .driver_config import DriverConfig +from .openai_driver_config import OpenAiDriverConfig +from .azure_openai_driver_config import AzureOpenAiDriverConfig +from .amazon_bedrock_driver_config import AmazonBedrockDriverConfig +from .anthropic_driver_config import AnthropicDriverConfig +from .google_driver_config import GoogleDriverConfig +from .cohere_driver_config import CohereDriverConfig +from .config import config __all__ = [ "BaseConfig", - "BaseStructureConfig", - "StructureConfig", - "OpenAiStructureConfig", - "AzureOpenAiStructureConfig", - "AmazonBedrockStructureConfig", - "AnthropicStructureConfig", - "GoogleStructureConfig", - "CohereStructureConfig", + "BaseDriverConfig", + "DriverConfig", + "OpenAiDriverConfig", + "AzureOpenAiDriverConfig", + "AmazonBedrockDriverConfig", + "AnthropicDriverConfig", + "GoogleDriverConfig", + "CohereDriverConfig", + "config", ] diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_driver_config.py similarity index 84% rename from griptape/config/amazon_bedrock_structure_config.py rename to griptape/config/amazon_bedrock_driver_config.py index 3ad7f8f48..a07300638 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_driver_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AmazonBedrockImageGenerationDriver, AmazonBedrockImageQueryDriver, @@ -25,14 +25,14 @@ @define -class AmazonBedrockStructureConfig(StructureConfig): +class AmazonBedrockDriverConfig(DriverConfig): session: boto3.Session = field( default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True, metadata={"serializable": False}, ) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory( lambda self: AmazonBedrockPromptDriver( session=self.session, @@ -43,7 +43,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1"), takes_self=True, @@ -51,7 +51,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory( lambda self: AmazonBedrockImageGenerationDriver( session=self.session, @@ -63,7 +63,7 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageGenerationDriver = field( + image_query: BaseImageGenerationDriver = field( default=Factory( lambda self: AmazonBedrockImageQueryDriver( session=self.session, @@ -75,8 +75,8 @@ class AmazonBedrockStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_driver_config.py similarity index 76% rename from griptape/config/anthropic_structure_config.py rename to griptape/config/anthropic_driver_config.py index 1bb5bf49b..642a3fced 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AnthropicImageQueryDriver, AnthropicPromptDriver, @@ -14,25 +14,25 @@ @define -class AnthropicStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class AnthropicDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: AnthropicPromptDriver(model="claude-3-5-sonnet-20240620")), metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: VoyageAiEmbeddingDriver(model="voyage-large-2")), metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")), ), kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory(lambda: AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620")), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/config/azure_openai_structure_config.py b/griptape/config/azure_openai_driver_config.py similarity index 88% rename from griptape/config/azure_openai_structure_config.py rename to griptape/config/azure_openai_driver_config.py index ce0303e34..c987a31b5 100644 --- a/griptape/config/azure_openai_structure_config.py +++ b/griptape/config/azure_openai_driver_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( AzureOpenAiChatPromptDriver, AzureOpenAiEmbeddingDriver, @@ -20,8 +20,8 @@ @define -class AzureOpenAiStructureConfig(StructureConfig): - """Azure OpenAI Structure Configuration. +class AzureOpenAiDriverConfig(DriverConfig): + """Azure OpenAI Driver Configuration. Attributes: azure_endpoint: The endpoint for the Azure OpenAI instance. @@ -43,7 +43,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": False}, ) api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory( lambda self: AzureOpenAiChatPromptDriver( model="gpt-4o", @@ -57,7 +57,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory( lambda self: AzureOpenAiImageGenerationDriver( model="dall-e-2", @@ -72,7 +72,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory( lambda self: AzureOpenAiImageQueryDriver( model="gpt-4o", @@ -86,7 +86,7 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: AzureOpenAiEmbeddingDriver( model="text-embedding-3-small", @@ -100,8 +100,8 @@ class AzureOpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), metadata={"serializable": True}, kw_only=True, ) diff --git a/griptape/config/base_config.py b/griptape/config/base_config.py index 241efadcd..9209aa4a4 100644 --- a/griptape/config/base_config.py +++ b/griptape/config/base_config.py @@ -4,6 +4,11 @@ from griptape.mixins.serializable_mixin import SerializableMixin +from .base_driver_config import BaseDriverConfig +from .logging_config import LoggingConfig -@define -class BaseConfig(SerializableMixin, ABC): ... + +@define(kw_only=True) +class BaseConfig(SerializableMixin, ABC): + drivers: BaseDriverConfig + logging: LoggingConfig diff --git a/griptape/config/base_driver_config.py b/griptape/config/base_driver_config.py new file mode 100644 index 000000000..df32d382e --- /dev/null +++ b/griptape/config/base_driver_config.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Optional + +from attrs import define, field + +from griptape.mixins import SerializableMixin + +if TYPE_CHECKING: + from griptape.drivers import ( + BaseAudioTranscriptionDriver, + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseTextToSpeechDriver, + BaseVectorStoreDriver, + ) + + +@define +class BaseDriverConfig(ABC, SerializableMixin): + prompt: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) + image_generation: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) + image_query: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) + embedding: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) + vector_store: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) + conversation_memory: Optional[BaseConversationMemoryDriver] = field( + default=None, + kw_only=True, + metadata={"serializable": True}, + ) + text_to_speech: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) + audio_transcription: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py deleted file mode 100644 index c2aa82d7e..000000000 --- a/griptape/config/base_structure_config.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING, Optional - -from attrs import define, field - -from griptape.config import BaseConfig -from griptape.utils import dict_merge - -if TYPE_CHECKING: - from griptape.drivers import ( - BaseAudioTranscriptionDriver, - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BaseImageQueryDriver, - BasePromptDriver, - BaseTextToSpeechDriver, - BaseVectorStoreDriver, - ) - - -@define -class BaseStructureConfig(BaseConfig, ABC): - prompt_driver: BasePromptDriver = field(kw_only=True, metadata={"serializable": True}) - image_generation_driver: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True}) - image_query_driver: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True}) - embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True}) - vector_store_driver: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True}) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( - default=None, - kw_only=True, - metadata={"serializable": True}, - ) - text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) - audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) - - def merge_config(self, config: dict) -> BaseStructureConfig: - base_config = self.to_dict() - merged_config = dict_merge(base_config, config) - - return BaseStructureConfig.from_dict(merged_config) diff --git a/griptape/config/cohere_structure_config.py b/griptape/config/cohere_driver_config.py similarity index 76% rename from griptape/config/cohere_structure_config.py rename to griptape/config/cohere_driver_config.py index 2e896b9b0..7195f550f 100644 --- a/griptape/config/cohere_structure_config.py +++ b/griptape/config/cohere_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseEmbeddingDriver, BasePromptDriver, @@ -12,15 +12,15 @@ @define -class CohereStructureConfig(StructureConfig): +class CohereDriverConfig(DriverConfig): api_key: str = field(metadata={"serializable": False}, kw_only=True) - prompt_driver: BasePromptDriver = field( + prompt: BasePromptDriver = field( default=Factory(lambda self: CoherePromptDriver(model="command-r", api_key=self.api_key), takes_self=True), metadata={"serializable": True}, kw_only=True, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory( lambda self: CohereEmbeddingDriver( model="embed-english-v3.0", @@ -32,8 +32,8 @@ class CohereStructureConfig(StructureConfig): metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + vector_store: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/config/config.py b/griptape/config/config.py new file mode 100644 index 000000000..97d501abb --- /dev/null +++ b/griptape/config/config.py @@ -0,0 +1,15 @@ +from attrs import Factory, define, field + +from .base_config import BaseConfig +from .base_driver_config import BaseDriverConfig +from .logging_config import LoggingConfig +from .openai_driver_config import OpenAiDriverConfig + + +@define +class _Config(BaseConfig): + drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) + logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True) + + +config = _Config() diff --git a/griptape/config/structure_config.py b/griptape/config/driver_config.py similarity index 60% rename from griptape/config/structure_config.py rename to griptape/config/driver_config.py index ef95012ce..325591258 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/driver_config.py @@ -1,19 +1,11 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional from attrs import Factory, define, field -from griptape.config import BaseStructureConfig +from griptape.config import BaseDriverConfig from griptape.drivers import ( - BaseAudioTranscriptionDriver, - BaseConversationMemoryDriver, - BaseEmbeddingDriver, - BaseImageGenerationDriver, - BaseImageQueryDriver, - BasePromptDriver, - BaseTextToSpeechDriver, - BaseVectorStoreDriver, DummyAudioTranscriptionDriver, DummyEmbeddingDriver, DummyImageGenerationDriver, @@ -23,45 +15,57 @@ DummyVectorStoreDriver, ) +if TYPE_CHECKING: + from griptape.drivers import ( + BaseAudioTranscriptionDriver, + BaseConversationMemoryDriver, + BaseEmbeddingDriver, + BaseImageGenerationDriver, + BaseImageQueryDriver, + BasePromptDriver, + BaseTextToSpeechDriver, + BaseVectorStoreDriver, + ) + @define -class StructureConfig(BaseStructureConfig): - prompt_driver: BasePromptDriver = field( +class DriverConfig(BaseDriverConfig): + prompt: BasePromptDriver = field( kw_only=True, default=Factory(lambda: DummyPromptDriver()), metadata={"serializable": True}, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( kw_only=True, default=Factory(lambda: DummyImageGenerationDriver()), metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( kw_only=True, default=Factory(lambda: DummyImageQueryDriver()), metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory(lambda: DummyVectorStoreDriver()), kw_only=True, metadata={"serializable": True}, ) - conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( + conversation_memory: Optional[BaseConversationMemoryDriver] = field( default=None, kw_only=True, metadata={"serializable": True}, ) - text_to_speech_driver: BaseTextToSpeechDriver = field( + text_to_speech: BaseTextToSpeechDriver = field( default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True}, ) - audio_transcription_driver: BaseAudioTranscriptionDriver = field( + audio_transcription: BaseAudioTranscriptionDriver = field( default=Factory(lambda: DummyAudioTranscriptionDriver()), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_driver_config.py similarity index 75% rename from griptape/config/google_structure_config.py rename to griptape/config/google_driver_config.py index 66ed90b4b..a1089f0ee 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseEmbeddingDriver, BasePromptDriver, @@ -12,18 +12,18 @@ @define -class GoogleStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class GoogleDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: GooglePromptDriver(model="gemini-1.5-pro")), kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: GoogleEmbeddingDriver(model="models/embedding-001")), kw_only=True, metadata={"serializable": True}, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=GoogleEmbeddingDriver(model="models/embedding-001")), ), diff --git a/griptape/config/logging_config.py b/griptape/config/logging_config.py new file mode 100644 index 000000000..0c0fcc020 --- /dev/null +++ b/griptape/config/logging_config.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import logging + +from attrs import define, field +from rich.logging import RichHandler + + +@define +class LoggingConfig: + logger_name: str = field(default="griptape", kw_only=True) + logger_level: int = field( + default=logging.INFO, + kw_only=True, + on_setattr=lambda self, _, value: logging.getLogger(self.logger_name).setLevel(value), + ) + + def __attrs_post_init__(self) -> None: + logger = logging.getLogger(self.logger_name) + + logger.propagate = False + logger.setLevel(self.logger_level) + + logger.handlers = [RichHandler(show_time=True, show_path=False)] diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_driver_config.py similarity index 76% rename from griptape/config/openai_structure_config.py rename to griptape/config/openai_driver_config.py index 63806dfc9..35ccde43d 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_driver_config.py @@ -1,6 +1,6 @@ from attrs import Factory, define, field -from griptape.config import StructureConfig +from griptape.config import DriverConfig from griptape.drivers import ( BaseAudioTranscriptionDriver, BaseEmbeddingDriver, @@ -20,40 +20,40 @@ @define -class OpenAiStructureConfig(StructureConfig): - prompt_driver: BasePromptDriver = field( +class OpenAiDriverConfig(DriverConfig): + prompt: BasePromptDriver = field( default=Factory(lambda: OpenAiChatPromptDriver(model="gpt-4o")), metadata={"serializable": True}, kw_only=True, ) - image_generation_driver: BaseImageGenerationDriver = field( + image_generation: BaseImageGenerationDriver = field( default=Factory(lambda: OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512")), kw_only=True, metadata={"serializable": True}, ) - image_query_driver: BaseImageQueryDriver = field( + image_query: BaseImageQueryDriver = field( default=Factory(lambda: OpenAiImageQueryDriver(model="gpt-4o")), kw_only=True, metadata={"serializable": True}, ) - embedding_driver: BaseEmbeddingDriver = field( + embedding: BaseEmbeddingDriver = field( default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True}, kw_only=True, ) - vector_store_driver: BaseVectorStoreDriver = field( + vector_store: BaseVectorStoreDriver = field( default=Factory( lambda: LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")), ), kw_only=True, metadata={"serializable": True}, ) - text_to_speech_driver: BaseTextToSpeechDriver = field( + text_to_speech: BaseTextToSpeechDriver = field( default=Factory(lambda: OpenAiTextToSpeechDriver(model="tts")), kw_only=True, metadata={"serializable": True}, ) - audio_transcription_driver: BaseAudioTranscriptionDriver = field( + audio_transcription: BaseAudioTranscriptionDriver = field( default=Factory(lambda: OpenAiAudioTranscriptionDriver(model="whisper-1")), kw_only=True, metadata={"serializable": True}, diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index e52174c28..44f214d7c 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -5,12 +5,13 @@ from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory from griptape.utils import import_optional_dependency if TYPE_CHECKING: import boto3 + from griptape.memory.structure import BaseConversationMemory + @define class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver): @@ -38,6 +39,8 @@ def store(self, memory: BaseConversationMemory) -> None: ) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + response = self.table.get_item(Key=self._get_key()) if "Item" in response and self.value_attribute_key in response["Item"]: diff --git a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py index 8d6399e13..f7b6e7d6e 100644 --- a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py @@ -2,12 +2,14 @@ import os from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory + +if TYPE_CHECKING: + from griptape.memory.structure import BaseConversationMemory @define @@ -18,6 +20,8 @@ def store(self, memory: BaseConversationMemory) -> None: Path(self.file_path).write_text(memory.to_json()) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + if not os.path.exists(self.file_path): return None memory = BaseConversationMemory.from_json(Path(self.file_path).read_text()) diff --git a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py index 2ba3737e8..9afc2f204 100644 --- a/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/redis_conversation_memory_driver.py @@ -6,12 +6,13 @@ from attrs import Factory, define, field from griptape.drivers import BaseConversationMemoryDriver -from griptape.memory.structure import BaseConversationMemory from griptape.utils.import_utils import import_optional_dependency if TYPE_CHECKING: from redis import Redis + from griptape.memory.structure import BaseConversationMemory + @define class RedisConversationMemoryDriver(BaseConversationMemoryDriver): @@ -54,6 +55,8 @@ def store(self, memory: BaseConversationMemory) -> None: self.client.hset(self.index, self.conversation_id, memory.to_json()) def load(self) -> Optional[BaseConversationMemory]: + from griptape.memory.structure import BaseConversationMemory + key = self.index memory_json = self.client.hget(key, self.conversation_id) if memory_json: diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index 3631b2d17..cad8287d5 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -1,12 +1,15 @@ -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import AudioArtifact, TextArtifact +from griptape.config import config from griptape.drivers import BaseAudioTranscriptionDriver @define class AudioTranscriptionEngine: - audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True) + audio_transcription_driver: BaseAudioTranscriptionDriver = field( + default=Factory(lambda: config.drivers.audio_transcription), kw_only=True + ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: return self.audio_transcription_driver.try_run(audio) diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index af5d5a494..aad45a10a 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @@ -11,7 +13,9 @@ @define class TextToSpeechEngine: - text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True) + text_to_speech_driver: BaseTextToSpeechDriver = field( + default=Factory(lambda: config.drivers.text_to_speech), kw_only=True + ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: return self.text_to_speech_driver.try_text_to_audio(prompts=prompts) diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index f263ee0aa..4b1184e5e 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -6,6 +6,7 @@ from attrs import Attribute, Factory, define, field from griptape.chunkers import BaseChunker, TextChunker +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ErrorArtifact, ListArtifact @@ -17,7 +18,7 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 47a853871..9bec68b91 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -3,7 +3,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @@ -13,7 +15,9 @@ @define class BaseImageGenerationEngine(ABC): - image_generation_driver: BaseImageGenerationDriver = field(kw_only=True) + image_generation_driver: BaseImageGenerationDriver = field( + kw_only=True, default=Factory(lambda: config.drivers.image_generation) + ) @abstractmethod def run(self, prompts: list[str], *args, rulesets: Optional[list[Ruleset]], **kwargs) -> ImageArtifact: ... diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index d0a1e99d4..f2bd99544 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -2,7 +2,9 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field + +from griptape.config import config if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @@ -11,7 +13,7 @@ @define class ImageQueryEngine: - image_query_driver: BaseImageQueryDriver = field(kw_only=True) + image_query_driver: BaseImageQueryDriver = field(default=Factory(lambda: config.drivers.image_query), kw_only=True) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: return self.image_query_driver.query(query, images) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 99bdf5f5e..92e611223 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -5,6 +5,7 @@ from attrs import Factory, define, field from griptape.artifacts.text_artifact import TextArtifact +from griptape.config import config from griptape.engines.rag.modules import BaseResponseRagModule from griptape.mixins import RuleMixin from griptape.utils import J2 @@ -17,7 +18,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule, RuleMixin): - prompt_driver: BasePromptDriver = field() + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt)) answer_token_offset: int = field(default=400) metadata: Optional[str] = field(default=None) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index 0a07b4c50..6ce235fa5 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -5,6 +5,7 @@ from attrs import Factory, define, field from griptape import utils +from griptape.config import config from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -17,7 +18,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): - vector_store_driver: BaseVectorStoreDriver = field() + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index c5d8e695d..82c33a0ad 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -6,8 +6,8 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.message import Message +from griptape.common import Message, PromptStack +from griptape.config import config from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -22,7 +22,7 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/exceptions/dummy_exception.py b/griptape/exceptions/dummy_exception.py index 815cb245f..172aeadc6 100644 --- a/griptape/exceptions/dummy_exception.py +++ b/griptape/exceptions/dummy_exception.py @@ -2,7 +2,7 @@ class DummyError(Exception): def __init__(self, dummy_class_name: str, dummy_method_name: str) -> None: message = ( f"You have attempted to use a {dummy_class_name}'s {dummy_method_name} method. " - "This likely originated from using a `StructureConfig` without providing a Driver required for this feature." + "This likely originated from using a `DriverConfig` without providing a Driver required for this feature." ) super().__init__(message) diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index c3d3c501e..d6e3549af 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -3,9 +3,10 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Optional -from attrs import define, field +from attrs import Factory, define, field from griptape.common import PromptStack +from griptape.config import config from griptape.mixins import SerializableMixin if TYPE_CHECKING: @@ -16,7 +17,9 @@ @define class BaseConversationMemory(SerializableMixin, ABC): - driver: Optional[BaseConversationMemoryDriver] = field(default=None, kw_only=True) + driver: Optional[BaseConversationMemoryDriver] = field( + default=Factory(lambda: config.drivers.conversation_memory), kw_only=True + ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) autoload: bool = field(default=True, kw_only=True) @@ -64,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = self.structure.config.prompt_driver + prompt_driver = config.drivers.prompt temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index f29bbb767..4263e61c8 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -5,8 +5,8 @@ from attrs import Factory, define, field -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.message import Message +from griptape.common import Message, PromptStack +from griptape.config import config from griptape.memory.structure import ConversationMemory from griptape.utils import J2 @@ -18,7 +18,7 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - _prompt_driver: BasePromptDriver = field(kw_only=True, default=None, alias="prompt_driver") + prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: config.drivers.prompt)) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) @@ -27,19 +27,6 @@ class SummaryConversationMemory(ConversationMemory): kw_only=True, ) - @property - def prompt_driver(self) -> BasePromptDriver: - if self._prompt_driver is None: - if self.structure is not None: - self._prompt_driver = self.structure.config.prompt_driver - else: - raise ValueError("Prompt Driver is not set.") - return self._prompt_driver - - @prompt_driver.setter - def prompt_driver(self, value: BasePromptDriver) -> None: - self._prompt_driver = value - def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: stack = PromptStack() if self.summary: diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 62a517bc9..18f2cff80 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING, Any, Optional -from attrs import Attribute, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.config import config from griptape.engines.rag import RagContext, RagEngine from griptape.memory.task.storage import BaseArtifactStorage @@ -15,7 +16,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): - vector_store_driver: BaseVectorStoreDriver = field() + vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: config.drivers.vector_store)) rag_engine: Optional[RagEngine] = field(default=None) retrieval_rag_module_name: Optional[str] = field(default=None) summary_engine: Optional[BaseSummaryEngine] = field(default=None) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index b133a7b6b..a046da6a9 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -2,16 +2,18 @@ from typing import TYPE_CHECKING, Callable, Optional -from attrs import Attribute, define, field +from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable +from griptape.config import config from griptape.memory.structure import Run from griptape.structures import Structure from griptape.tasks import PromptTask, ToolkitTask if TYPE_CHECKING: from griptape.artifacts import BaseArtifact + from griptape.drivers import BasePromptDriver from griptape.tasks import BaseTask from griptape.tools import BaseTool @@ -21,6 +23,8 @@ class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) + stream: bool = field(default=Factory(lambda: config.drivers.prompt.stream), kw_only=True) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) @@ -32,11 +36,20 @@ def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None: # noqa: FB def __attrs_post_init__(self) -> None: super().__attrs_post_init__() + + self.prompt_driver.stream = self.stream if len(self.tasks) == 0: if self.tools: - task = ToolkitTask(self.input, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries) + task = ToolkitTask( + self.input, + prompt_driver=self.prompt_driver, + tools=self.tools, + max_meta_memory_entries=self.max_meta_memory_entries, + ) else: - task = PromptTask(self.input, max_meta_memory_entries=self.max_meta_memory_entries) + task = PromptTask( + self.input, prompt_driver=self.prompt_driver, max_meta_memory_entries=self.max_meta_memory_entries + ) self.add_task(task) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 8f095dfeb..eb56c2cf1 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -1,24 +1,14 @@ from __future__ import annotations -import logging import uuid from abc import ABC, abstractmethod -from logging import Logger from typing import TYPE_CHECKING, Any, Optional from attrs import Attribute, Factory, define, field -from rich.logging import RichHandler from griptape.artifacts import BaseArtifact, BlobArtifact, TextArtifact from griptape.common import observable -from griptape.config import BaseStructureConfig, OpenAiStructureConfig, StructureConfig -from griptape.drivers import ( - BaseEmbeddingDriver, - BasePromptDriver, - LocalVectorStoreDriver, - OpenAiChatPromptDriver, - OpenAiEmbeddingDriver, -) +from griptape.config import config from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( @@ -31,7 +21,6 @@ from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage -from griptape.utils import deprecation_warn if TYPE_CHECKING: from griptape.memory.structure import BaseConversationMemory @@ -41,26 +30,12 @@ @define class Structure(ABC): - LOGGER_NAME = "griptape" - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) - stream: Optional[bool] = field(default=None, kw_only=True) - prompt_driver: Optional[BasePromptDriver] = field(default=None) - embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - config: BaseStructureConfig = field( - default=Factory(lambda self: self.default_config, takes_self=True), - kw_only=True, - ) rulesets: list[Ruleset] = field(factory=list, kw_only=True) rules: list[Rule] = field(factory=list, kw_only=True) tasks: list[BaseTask] = field(factory=list, kw_only=True) - custom_logger: Optional[Logger] = field(default=None, kw_only=True) - logger_level: int = field(default=logging.INFO, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( - default=Factory( - lambda self: ConversationMemory(driver=self.config.conversation_memory_driver), - takes_self=True, - ), + default=Factory(lambda: ConversationMemory()), kw_only=True, ) rag_engine: RagEngine = field(default=Factory(lambda self: self.default_rag_engine, takes_self=True), kw_only=True) @@ -71,7 +46,6 @@ class Structure(ABC): meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True) fail_fast: bool = field(default=True, kw_only=True) _execution_args: tuple = () - _logger: Optional[Logger] = None @rulesets.validator # pyright: ignore[reportAttributeAccessIssue] def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None: @@ -100,39 +74,10 @@ def __attrs_post_init__(self) -> None: def __add__(self, other: BaseTask | list[BaseTask]) -> list[BaseTask]: return self.add_tasks(*other) if isinstance(other, list) else self + [other] - @prompt_driver.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_prompt_driver(self, attribute: Attribute, value: BasePromptDriver) -> None: - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver` instead.") - - @embedding_driver.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_embedding_driver(self, attribute: Attribute, value: BaseEmbeddingDriver) -> None: - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.embedding_driver` instead.") - - @stream.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_stream(self, attribute: Attribute, value: bool) -> None: # noqa: FBT001 - if value is not None: - deprecation_warn(f"`{attribute.name}` is deprecated, use `config.prompt_driver.stream` instead.") - @property def execution_args(self) -> tuple: return self._execution_args - @property - def logger(self) -> Logger: - if self.custom_logger: - return self.custom_logger - else: - if self._logger is None: - self._logger = logging.getLogger(self.LOGGER_NAME) - - self._logger.propagate = False - self._logger.level = self.logger_level - - self._logger.handlers = [RichHandler(show_time=True, show_path=False)] - return self._logger - @property def input_task(self) -> Optional[BaseTask]: return self.tasks[0] if self.tasks else None @@ -149,38 +94,14 @@ def output(self) -> Optional[BaseArtifact]: def finished_tasks(self) -> list[BaseTask]: return [s for s in self.tasks if s.is_finished()] - @property - def default_config(self) -> BaseStructureConfig: - if self.prompt_driver is not None or self.embedding_driver is not None or self.stream is not None: - config = StructureConfig() - - prompt_driver = OpenAiChatPromptDriver(model="gpt-4o") if self.prompt_driver is None else self.prompt_driver - - embedding_driver = OpenAiEmbeddingDriver() if self.embedding_driver is None else self.embedding_driver - - if self.stream is not None: - prompt_driver.stream = self.stream - - vector_store_driver = LocalVectorStoreDriver(embedding_driver=embedding_driver) - - config.prompt_driver = prompt_driver - config.vector_store_driver = vector_store_driver - config.embedding_driver = embedding_driver - else: - config = OpenAiStructureConfig() - - return config - @property def default_rag_engine(self) -> RagEngine: return RagEngine( retrieval_stage=RetrievalRagStage( - retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=self.config.vector_store_driver)], + retrieval_modules=[VectorStoreRetrievalRagModule()], ), response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=self.config.prompt_driver, rulesets=self.rulesets) - ], + response_modules=[PromptResponseRagModule(prompt_driver=config.drivers.prompt, rulesets=self.rulesets)], ), ) @@ -191,10 +112,10 @@ def default_task_memory(self) -> TaskMemory: TextArtifact: TextArtifactStorage( rag_engine=self.rag_engine, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - vector_store_driver=self.config.vector_store_driver, - summary_engine=PromptSummaryEngine(prompt_driver=self.config.prompt_driver), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=self.config.prompt_driver), - json_extraction_engine=JsonExtractionEngine(prompt_driver=self.config.prompt_driver), + vector_store_driver=config.drivers.vector_store, + summary_engine=PromptSummaryEngine(prompt_driver=config.drivers.prompt), + csv_extraction_engine=CsvExtractionEngine(prompt_driver=config.drivers.prompt), + json_extraction_engine=JsonExtractionEngine(prompt_driver=config.drivers.prompt), ), BlobArtifact: BlobArtifactStorage(), }, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index d600c80a5..0f885d260 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import logging import re from typing import TYPE_CHECKING, Callable, Optional @@ -10,6 +11,7 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction +from griptape.config import config from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask @@ -18,6 +20,8 @@ if TYPE_CHECKING: from griptape.memory import TaskMemory +logger = logging.getLogger(config.logging.logger_name) + @define class ActionsSubtask(BaseTask): @@ -86,7 +90,7 @@ def attach_to(self, parent_task: BaseTask) -> None: else: self.__init_from_artifacts(self.input) except Exception as e: - self.structure.logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) + logger.error("Subtask %s\nError parsing tool action: %s", self.origin_task.id, e) self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) @@ -109,7 +113,7 @@ def before_run(self) -> None: *([f"\nThought: {self.thought}"] if self.thought is not None else []), f"\nActions: {self.actions_to_json()}", ] - self.structure.logger.info("".join(parts)) + logger.info("".join(parts)) def run(self) -> BaseArtifact: try: @@ -128,7 +132,7 @@ def run(self) -> BaseArtifact: actions_output.append(output) self.output = ListArtifact(actions_output) except Exception as e: - self.structure.logger.exception("Subtask %s\n%s", self.id, e) + logger.exception("Subtask %s\n%s", self.id, e) self.output = ErrorArtifact(str(e), exception=e) if self.output is not None: @@ -169,7 +173,7 @@ def after_run(self) -> None: subtask_actions=self.actions_to_dicts(), ), ) - self.structure.logger.info("Subtask %s\nResponse: %s", self.id, response) + logger.info("Subtask %s\nResponse: %s", self.id, response) def actions_to_dicts(self) -> list[dict]: json_list = [] @@ -257,7 +261,7 @@ def __parse_actions(self, actions_matches: list[str]) -> None: self.actions = [self.__process_action_object(action_object) for action_object in actions_list] except json.JSONDecodeError as e: - self.structure.logger.exception("Subtask %s\nInvalid actions JSON: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nInvalid actions JSON: %s", self.origin_task.id, e) self.output = ErrorArtifact(f"Actions JSON decoding error: {e}", exception=e) @@ -314,10 +318,10 @@ def __validate_action(self, action: ToolAction) -> None: if activity_schema: activity_schema.validate(action.input) except schema.SchemaError as e: - self.structure.logger.exception("Subtask %s\nInvalid action JSON: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nInvalid action JSON: %s", self.origin_task.id, e) action.output = ErrorArtifact(f"Activity input JSON validation error: {e}", exception=e) except SyntaxError as e: - self.structure.logger.exception("Subtask %s\nSyntax error: %s", self.origin_task.id, e) + logger.exception("Subtask %s\nSyntax error: %s", self.origin_task.id, e) action.output = ErrorArtifact(f"Syntax error: {e}", exception=e) diff --git a/griptape/tasks/audio_transcription_task.py b/griptape/tasks/audio_transcription_task.py index 3a4b17b9e..3d83cf7e7 100644 --- a/griptape/tasks/audio_transcription_task.py +++ b/griptape/tasks/audio_transcription_task.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import AudioTranscriptionEngine from griptape.tasks.base_audio_input_task import BaseAudioInputTask @@ -13,26 +13,10 @@ @define class AudioTranscriptionTask(BaseAudioInputTask): - _audio_transcription_engine: AudioTranscriptionEngine = field( - default=None, + audio_transcription_engine: AudioTranscriptionEngine = field( + default=Factory(lambda: AudioTranscriptionEngine()), kw_only=True, - alias="audio_transcription_engine", ) - @property - def audio_transcription_engine(self) -> AudioTranscriptionEngine: - if self._audio_transcription_engine is None: - if self.structure is not None: - self._audio_transcription_engine = AudioTranscriptionEngine( - audio_transcription_driver=self.structure.config.audio_transcription_driver, - ) - else: - raise ValueError("Audio Generation Engine is not set.") - return self._audio_transcription_engine - - @audio_transcription_engine.setter - def audio_transcription_engine(self, value: AudioTranscriptionEngine) -> None: - self._audio_transcription_engine = value - def run(self) -> TextArtifact: return self.audio_transcription_engine.run(self.input) diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index d2657561d..00774e0a2 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -1,21 +1,25 @@ from __future__ import annotations +import logging from abc import ABC from attrs import define +from griptape.config import config from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.tasks import BaseTask +logger = logging.getLogger(config.logging.logger_name) + @define class BaseAudioGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py index 517c03a15..8a470bb85 100644 --- a/griptape/tasks/base_audio_input_task.py +++ b/griptape/tasks/base_audio_input_task.py @@ -1,14 +1,18 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact +from griptape.config.config import config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask +logger = logging.getLogger(config.logging.logger_name) + @define class BaseAudioInputTask(RuleMixin, BaseTask, ABC): @@ -30,9 +34,9 @@ def input(self, value: AudioArtifact | Callable[[BaseTask], AudioArtifact]) -> N def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index d32e8f142..f94ff8de2 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os from abc import ABC from pathlib import Path @@ -7,6 +8,7 @@ from attrs import Attribute, define, field +from griptape.config import config from griptape.loaders import ImageLoader from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.rules import Rule, Ruleset @@ -16,6 +18,9 @@ from griptape.artifacts import MediaArtifact +logger = logging.getLogger(config.logging.logger_name) + + @define class BaseImageGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): """Provides a base class for image generation-related tasks. @@ -60,5 +65,5 @@ def all_negative_rulesets(self) -> list[Ruleset]: return task_rulesets def _read_from_file(self, path: str) -> MediaArtifact: - self.structure.logger.info("Reading image from %s", os.path.abspath(path)) + logger.info("Reading image from %s", os.path.abspath(path)) return ImageLoader().load(Path(path).read_bytes()) diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index a0d8cb9ac..c688a1129 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -1,15 +1,19 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import Factory, define, field from griptape.artifacts import ListArtifact, TextArtifact +from griptape.config import config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 +logger = logging.getLogger(config.logging.logger_name) + @define class BaseMultiTextInputTask(RuleMixin, BaseTask, ABC): @@ -48,9 +52,9 @@ def before_run(self) -> None: super().before_run() joined_input = "\n".join([i.to_text() for i in self.input]) - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, joined_input) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, joined_input) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index ade656f87..b3086bebb 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import uuid from abc import ABC, abstractmethod from concurrent import futures @@ -9,6 +10,7 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact +from griptape.config import config from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus if TYPE_CHECKING: @@ -16,6 +18,8 @@ from griptape.memory.meta import BaseMetaEntry from griptape.structures import Structure +logger = logging.getLogger(config.logging.logger_name) + @define class BaseTask(ABC): @@ -159,7 +163,7 @@ def execute(self) -> Optional[BaseArtifact]: self.after_run() except Exception as e: - self.structure.logger.exception("%s %s\n%s", self.__class__.__name__, self.id, e) + logger.exception("%s %s\n%s", self.__class__.__name__, self.id, e) self.output = ErrorArtifact(str(e), exception=e) finally: diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 90f60efcd..1c9dfc023 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -1,15 +1,19 @@ from __future__ import annotations +import logging from abc import ABC from typing import Callable from attrs import define, field from griptape.artifacts import TextArtifact +from griptape.config import config from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 +logger = logging.getLogger(config.logging.logger_name) + @define class BaseTextInputTask(RuleMixin, BaseTask, ABC): @@ -36,9 +40,9 @@ def input(self, value: str | TextArtifact | Callable[[BaseTask], TextArtifact]) def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/csv_extraction_task.py b/griptape/tasks/csv_extraction_task.py index 538596dfe..c252893de 100644 --- a/griptape/tasks/csv_extraction_task.py +++ b/griptape/tasks/csv_extraction_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import CsvExtractionEngine from griptape.tasks import ExtractionTask @@ -8,17 +8,4 @@ @define class CsvExtractionTask(ExtractionTask): - _extraction_engine: CsvExtractionEngine = field(default=None, kw_only=True, alias="extraction_engine") - - @property - def extraction_engine(self) -> CsvExtractionEngine: - if self._extraction_engine is None: - if self.structure is not None: - self._extraction_engine = CsvExtractionEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Extraction Engine is not set.") - return self._extraction_engine - - @extraction_engine.setter - def extraction_engine(self, value: CsvExtractionEngine) -> None: - self._extraction_engine = value + extraction_engine: CsvExtractionEngine = field(default=Factory(lambda: CsvExtractionEngine()), kw_only=True) diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index d8f492693..a1c18eff0 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -13,12 +13,8 @@ @define class ExtractionTask(BaseTextInputTask): - _extraction_engine: BaseExtractionEngine = field(kw_only=True, default=None, alias="extraction_engine") + extraction_engine: BaseExtractionEngine = field(kw_only=True) args: dict = field(kw_only=True) - @property - def extraction_engine(self) -> BaseExtractionEngine: - return self._extraction_engine - def run(self) -> ListArtifact | ErrorArtifact: return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args) diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index ea1b53739..1c77bbc0a 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import ImageQueryEngine @@ -24,7 +24,7 @@ class ImageQueryTask(BaseTask): image_query_engine: The engine used to execute the query. """ - _image_query_engine: ImageQueryEngine = field(default=None, kw_only=True, alias="image_query_engine") + image_query_engine: ImageQueryEngine = field(default=Factory(lambda: ImageQueryEngine()), kw_only=True) _input: ( tuple[str, list[ImageArtifact]] | tuple[TextArtifact, list[ImageArtifact]] @@ -62,19 +62,6 @@ def input( ) -> None: self._input = value - @property - def image_query_engine(self) -> ImageQueryEngine: - if self._image_query_engine is None: - if self.structure is not None: - self._image_query_engine = ImageQueryEngine(image_query_driver=self.structure.config.image_query_driver) - else: - raise ValueError("Image Query Engine is not set.") - return self._image_query_engine - - @image_query_engine.setter - def image_query_engine(self, value: ImageQueryEngine) -> None: - self._image_query_engine = value - def run(self) -> TextArtifact: query = self.input.value[0] diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index 07872d2dd..0ed28a11b 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import InpaintingImageGenerationEngine @@ -28,10 +28,9 @@ class InpaintingImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: InpaintingImageGenerationEngine = field( - default=None, + image_generation_engine: InpaintingImageGenerationEngine = field( + default=Factory(lambda: InpaintingImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: ( tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact @@ -60,21 +59,6 @@ def input( ) -> None: self._input = value - @property - def image_generation_engine(self) -> InpaintingImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = InpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: InpaintingImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/griptape/tasks/json_extraction_task.py b/griptape/tasks/json_extraction_task.py index ce51b316f..94db187da 100644 --- a/griptape/tasks/json_extraction_task.py +++ b/griptape/tasks/json_extraction_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from attrs import define, field +from attrs import Factory, define, field from griptape.engines import JsonExtractionEngine from griptape.tasks import ExtractionTask @@ -8,17 +8,4 @@ @define class JsonExtractionTask(ExtractionTask): - _extraction_engine: JsonExtractionEngine = field(default=None, kw_only=True, alias="extraction_engine") - - @property - def extraction_engine(self) -> JsonExtractionEngine: - if self._extraction_engine is None: - if self.structure is not None: - self._extraction_engine = JsonExtractionEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Extraction Engine is not set.") - return self._extraction_engine - - @extraction_engine.setter - def extraction_engine(self, value: JsonExtractionEngine) -> None: - self._extraction_engine = value + extraction_engine: JsonExtractionEngine = field(default=Factory(lambda: JsonExtractionEngine()), kw_only=True) diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index 3fc85a084..6b63709db 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import OutpaintingImageGenerationEngine @@ -28,10 +28,9 @@ class OutpaintingImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: OutpaintingImageGenerationEngine = field( - default=None, + image_generation_engine: OutpaintingImageGenerationEngine = field( + default=Factory(lambda: OutpaintingImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: ( tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact @@ -60,22 +59,6 @@ def input( ) -> None: self._input = value - @property - def image_generation_engine(self) -> OutpaintingImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = OutpaintingImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: OutpaintingImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index 0e06448bc..4d3356392 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.engines import PromptImageGenerationEngine @@ -32,10 +32,9 @@ class PromptImageGenerationTask(BaseImageGenerationTask): _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field( default=DEFAULT_INPUT_TEMPLATE, alias="input" ) - _image_generation_engine: PromptImageGenerationEngine = field( - default=None, + image_generation_engine: PromptImageGenerationEngine = field( + default=Factory(lambda: PromptImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) @property @@ -51,21 +50,6 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - @property - def image_generation_engine(self) -> PromptImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = PromptImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: PromptImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: image_artifact = self.image_generation_engine.run( prompts=[self.input.to_text()], diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 386ebe239..a8038832d 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,11 +1,13 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Callable, Optional from attrs import Factory, define, field from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack +from griptape.config import config from griptape.mixins import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -13,10 +15,12 @@ if TYPE_CHECKING: from griptape.drivers import BasePromptDriver +logger = logging.getLogger(config.logging.logger_name) + @define class PromptTask(RuleMixin, BaseTask): - _prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True, alias="prompt_driver") + prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.drivers.prompt), kw_only=True) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, @@ -56,15 +60,6 @@ def prompt_stack(self) -> PromptStack: return stack - @property - def prompt_driver(self) -> BasePromptDriver: - if self._prompt_driver is None: - if self.structure is not None: - self._prompt_driver = self.structure.config.prompt_driver - else: - raise ValueError("Prompt Driver is not set") - return self._prompt_driver - def default_system_template_generator(self, _: PromptTask) -> str: return J2("tasks/prompt_task/system.j2").render( rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets), @@ -73,12 +68,12 @@ def default_system_template_generator(self, _: PromptTask) -> str: def before_run(self) -> None: super().before_run() - self.structure.logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) + logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text()) def after_run(self) -> None: super().after_run() - self.structure.logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) + logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) def run(self) -> BaseArtifact: message = self.prompt_driver.run(self.prompt_stack) diff --git a/griptape/tasks/rag_task.py b/griptape/tasks/rag_task.py index 2f44fdfa4..b7ea8d7c7 100644 --- a/griptape/tasks/rag_task.py +++ b/griptape/tasks/rag_task.py @@ -1,32 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact +from griptape.engines.rag import RagEngine from griptape.tasks import BaseTextInputTask -if TYPE_CHECKING: - from griptape.engines.rag import RagEngine - @define class RagTask(BaseTextInputTask): - _rag_engine: RagEngine = field(kw_only=True, default=None, alias="rag_engine") - - @property - def rag_engine(self) -> RagEngine: - if self._rag_engine is None: - if self.structure is not None: - self._rag_engine = self.structure.rag_engine - else: - raise ValueError("rag_engine is not set.") - return self._rag_engine - - @rag_engine.setter - def rag_engine(self, value: RagEngine) -> None: - self._rag_engine = value + rag_engine: RagEngine = field(kw_only=True, default=Factory(lambda: RagEngine())) def run(self) -> BaseArtifact: outputs = self.rag_engine.process_query(self.input.to_text()).outputs diff --git a/griptape/tasks/text_summary_task.py b/griptape/tasks/text_summary_task.py index 5bd1b547e..dc1a7b8be 100644 --- a/griptape/tasks/text_summary_task.py +++ b/griptape/tasks/text_summary_task.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.engines import PromptSummaryEngine @@ -14,20 +14,7 @@ @define class TextSummaryTask(BaseTextInputTask): - _summary_engine: Optional[BaseSummaryEngine] = field(default=None, alias="summary_engine") - - @property - def summary_engine(self) -> Optional[BaseSummaryEngine]: - if self._summary_engine is None: - if self.structure is not None: - self._summary_engine = PromptSummaryEngine(prompt_driver=self.structure.config.prompt_driver) - else: - raise ValueError("Summary Engine is not set.") - return self._summary_engine - - @summary_engine.setter - def summary_engine(self, value: BaseSummaryEngine) -> None: - self._summary_engine = value + summary_engine: BaseSummaryEngine = field(default=Factory(lambda: PromptSummaryEngine()), kw_only=True) def run(self) -> TextArtifact: return TextArtifact(self.summary_engine.summarize_text(self.input.to_text(), rulesets=self.all_rulesets)) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index 3ca503dfe..680a67603 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import TextArtifact from griptape.engines import TextToSpeechEngine @@ -19,7 +19,7 @@ class TextToSpeechTask(BaseAudioGenerationTask): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(default=DEFAULT_INPUT_TEMPLATE) - _text_to_speech_engine: TextToSpeechEngine = field(default=None, kw_only=True, alias="text_to_speech_engine") + text_to_speech_engine: TextToSpeechEngine = field(default=Factory(lambda: TextToSpeechEngine()), kw_only=True) @property def input(self) -> TextArtifact: @@ -34,21 +34,6 @@ def input(self) -> TextArtifact: def input(self, value: TextArtifact) -> None: self._input = value - @property - def text_to_speech_engine(self) -> TextToSpeechEngine: - if self._text_to_speech_engine is None: - if self.structure is not None: - self._text_to_speech_engine = TextToSpeechEngine( - text_to_speech_driver=self.structure.config.text_to_speech_driver, - ) - else: - raise ValueError("Audio Generation Engine is not set.") - return self._text_to_speech_engine - - @text_to_speech_engine.setter - def text_to_speech_engine(self, value: TextToSpeechEngine) -> None: - self._text_to_speech_engine = value - def run(self) -> AudioArtifact: audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets) diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index c162a7192..e3feaeac5 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -2,7 +2,7 @@ from typing import Callable -from attrs import define, field +from attrs import Factory, define, field from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact from griptape.engines import VariationImageGenerationEngine @@ -28,10 +28,9 @@ class VariationImageGenerationTask(BaseImageGenerationTask): output_file: If provided, the generated image will be written to disk as output_file. """ - _image_generation_engine: VariationImageGenerationEngine = field( - default=None, + image_generation_engine: VariationImageGenerationEngine = field( + default=Factory(lambda: VariationImageGenerationEngine()), kw_only=True, - alias="image_generation_engine", ) _input: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact = field( default=None, alias="input" @@ -57,21 +56,6 @@ def input(self) -> ListArtifact: def input(self, value: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact]) -> None: self._input = value - @property - def image_generation_engine(self) -> VariationImageGenerationEngine: - if self._image_generation_engine is None: - if self.structure is not None: - self._image_generation_engine = VariationImageGenerationEngine( - image_generation_driver=self.structure.config.image_generation_driver, - ) - else: - raise ValueError("Image Generation Engine is not set.") - return self._image_generation_engine - - @image_generation_engine.setter - def image_generation_engine(self, value: VariationImageGenerationEngine) -> None: - self._image_generation_engine = value - def run(self) -> ImageArtifact: prompt_artifact = self.input[0] diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index e98eeaa4d..07fea92d8 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,12 +25,19 @@ class Chat: ) def default_output_fn(self, text: str) -> None: - if self.structure.config.prompt_driver.stream: + from griptape.tasks.prompt_task import PromptTask + + streaming_tasks = [ + task for task in self.structure.tasks if isinstance(task, PromptTask) and task.prompt_driver.stream + ] + if streaming_tasks: print(text, end="", flush=True) # noqa: T201 else: print(text) # noqa: T201 def start(self) -> None: + from griptape.config import config + if self.intro_text: self.output_fn(self.intro_text) while True: @@ -40,7 +47,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if self.structure.config.prompt_driver.stream: + if config.drivers.prompt.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index fd64a0f52..6da58b9e6 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -34,8 +34,13 @@ class Stream: @structure.validator # pyright: ignore[reportAttributeAccessIssue] def validate_structure(self, _: Attribute, structure: Structure) -> None: - if not structure.config.prompt_driver.stream: - raise ValueError("prompt driver does not have streaming enabled, enable with stream=True") + from griptape.tasks import PromptTask + + streaming_tasks = [ + task for task in structure.tasks if isinstance(task, PromptTask) and task.prompt_driver.stream + ] + if not streaming_tasks: + raise ValueError("Structure does not have any streaming tasks, enable with stream=True") _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) diff --git a/tests/mocks/docker/fake_api.py b/tests/mocks/docker/fake_api.py index 881093057..00e750232 100644 --- a/tests/mocks/docker/fake_api.py +++ b/tests/mocks/docker/fake_api.py @@ -154,7 +154,7 @@ def get_fake_inspect_container(*, tty=False): status_code = 200 response = { "Id": FAKE_CONTAINER_ID, - "Config": {"Labels": {"foo": "bar"}, "Privileged": True, "Tty": tty}, + "config": {"Labels": {"foo": "bar"}, "Privileged": True, "Tty": tty}, "ID": FAKE_CONTAINER_ID, "Image": "busybox:latest", "Name": "foobar", @@ -166,7 +166,7 @@ def get_fake_inspect_container(*, tty=False): "StartedAt": "2013-09-25T14:01:18.869545111+02:00", "Ghost": False, }, - "HostConfig": {"LogConfig": {"Type": "json-file", "Config": {}}}, + "HostConfig": {"LogConfig": {"Type": "json-file", "config": {}}}, "MacAddress": "02:42:ac:11:00:0a", } return status_code, response @@ -179,7 +179,7 @@ def get_fake_inspect_image(): "Parent": "27cf784147099545", "Created": "2013-03-23T22:24:18.818426-07:00", "Container": FAKE_CONTAINER_ID, - "Config": {"Labels": {"bar": "foo"}}, + "config": {"Labels": {"bar": "foo"}}, "ContainerConfig": { "Hostname": "", "User": "", @@ -446,7 +446,7 @@ def get_fake_network_list(): "Driver": "bridge", "EnableIPv6": False, "Internal": False, - "IPAM": {"Driver": "default", "Config": [{"Subnet": "172.17.0.0/16"}]}, + "IPAM": {"Driver": "default", "config": [{"Subnet": "172.17.0.0/16"}]}, "Containers": { FAKE_CONTAINER_ID: { "EndpointID": "ed2419a97c1d99", diff --git a/tests/mocks/mock_driver_config.py b/tests/mocks/mock_driver_config.py new file mode 100644 index 000000000..6b152721d --- /dev/null +++ b/tests/mocks/mock_driver_config.py @@ -0,0 +1,27 @@ +from attrs import Factory, define, field + +from griptape.config import DriverConfig +from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver +from tests.mocks.mock_embedding_driver import MockEmbeddingDriver +from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver +from tests.mocks.mock_image_query_driver import MockImageQueryDriver +from tests.mocks.mock_prompt_driver import MockPromptDriver + + +@define +class MockDriverConfig(DriverConfig): + prompt: MockPromptDriver = field(default=Factory(lambda: MockPromptDriver()), metadata={"serializable": True}) + image_generation: MockImageGenerationDriver = field( + default=Factory(lambda: MockImageGenerationDriver()), + metadata={"serializable": True}, + ) + image_query: MockImageQueryDriver = field( + default=Factory(lambda: MockImageQueryDriver()), metadata={"serializable": True} + ) + embedding: MockEmbeddingDriver = field( + default=Factory(lambda: MockEmbeddingDriver()), metadata={"serializable": True} + ) + vector_store: LocalVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True), + metadata={"serializable": True}, + ) diff --git a/tests/mocks/mock_image_generation_driver.py b/tests/mocks/mock_image_generation_driver.py index 573eb0fc4..f8d6d89ce 100644 --- a/tests/mocks/mock_image_generation_driver.py +++ b/tests/mocks/mock_image_generation_driver.py @@ -10,6 +10,8 @@ @define class MockImageGenerationDriver(BaseImageGenerationDriver): + model: str = "test-model" + def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: return ImageArtifact(value="mock image", width=512, height=512, format="png") diff --git a/tests/mocks/mock_structure_config.py b/tests/mocks/mock_structure_config.py deleted file mode 100644 index 3f95288f4..000000000 --- a/tests/mocks/mock_structure_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from attrs import Factory, define, field - -from griptape.config import StructureConfig -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_image_query_driver import MockImageQueryDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver - - -@define -class MockStructureConfig(StructureConfig): - prompt_driver: MockPromptDriver = field( - default=Factory(lambda: MockPromptDriver()), metadata={"serializable": True} - ) - image_generation_driver: MockImageGenerationDriver = field( - default=Factory(lambda: MockImageGenerationDriver(model="dall-e-2")), metadata={"serializable": True} - ) - image_query_driver: MockImageQueryDriver = field( - default=Factory(lambda: MockImageQueryDriver(model="gpt-4-vision-preview")), metadata={"serializable": True} - ) - embedding_driver: MockEmbeddingDriver = field( - default=Factory(lambda: MockEmbeddingDriver(model="text-embedding-3-small")), metadata={"serializable": True} - ) diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_driver_config.py similarity index 71% rename from tests/unit/config/test_amazon_bedrock_structure_config.py rename to tests/unit/config/test_amazon_bedrock_driver_config.py index afe9b3720..57a80809e 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_driver_config.py @@ -1,11 +1,11 @@ import boto3 import pytest -from griptape.config import AmazonBedrockStructureConfig +from griptape.config import AmazonBedrockDriverConfig from tests.utils.aws import mock_aws_credentials -class TestAmazonBedrockStructureConfig: +class TestAmazonBedrockDriverConfig: @pytest.fixture(autouse=True) def _run_before_and_after_tests(self): mock_aws_credentials() @@ -13,11 +13,11 @@ def _run_before_and_after_tests(self): @pytest.fixture() def config(self): mock_aws_credentials() - return AmazonBedrockStructureConfig() + return AmazonBedrockDriverConfig() @pytest.fixture() def config_with_values(self): - return AmazonBedrockStructureConfig( + return AmazonBedrockDriverConfig( session=boto3.Session( aws_access_key_id="testing", aws_secret_access_key="testing", region_name="region-value" ) @@ -25,9 +25,9 @@ def config_with_values(self): def test_to_dict(self, config): assert config.to_dict() == { - "conversation_memory_driver": None, - "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, - "image_generation_driver": { + "conversation_memory": None, + "embedding": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation": { "image_generation_model_driver": { "cfg_scale": 7, "outpainting_mode": "PRECISE", @@ -40,13 +40,13 @@ def test_to_dict(self, config): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "type": "AmazonBedrockImageQueryDriver", "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, - "prompt_driver": { + "prompt": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, @@ -55,32 +55,31 @@ def test_to_dict(self, config): "tool_choice": {"auto": {}}, "use_native_tools": True, }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", }, "type": "LocalVectorStoreDriver", }, - "type": "AmazonBedrockStructureConfig", - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "type": "AmazonBedrockDriverConfig", + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert AmazonBedrockStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert AmazonBedrockDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() def test_from_dict_with_values(self, config_with_values): assert ( - AmazonBedrockStructureConfig.from_dict(config_with_values.to_dict()).to_dict() - == config_with_values.to_dict() + AmazonBedrockDriverConfig.from_dict(config_with_values.to_dict()).to_dict() == config_with_values.to_dict() ) def test_to_dict_with_values(self, config_with_values): assert config_with_values.to_dict() == { - "conversation_memory_driver": None, - "embedding_driver": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, - "image_generation_driver": { + "conversation_memory": None, + "embedding": {"model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver"}, + "image_generation": { "image_generation_model_driver": { "cfg_scale": 7, "outpainting_mode": "PRECISE", @@ -93,13 +92,13 @@ def test_to_dict_with_values(self, config_with_values): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "type": "AmazonBedrockImageQueryDriver", "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, - "prompt_driver": { + "prompt": { "max_tokens": None, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, @@ -108,15 +107,15 @@ def test_to_dict_with_values(self, config_with_values): "tool_choice": {"auto": {}}, "use_native_tools": True, }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "model": "amazon.titan-embed-text-v1", "type": "AmazonBedrockTitanEmbeddingDriver", }, "type": "LocalVectorStoreDriver", }, - "type": "AmazonBedrockStructureConfig", - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "type": "AmazonBedrockDriverConfig", + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } assert config_with_values.session.region_name == "region-value" diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_driver_config.py similarity index 64% rename from tests/unit/config/test_anthropic_structure_config.py rename to tests/unit/config/test_anthropic_driver_config.py index 05519fa5e..a2ccbd25b 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_driver_config.py @@ -1,9 +1,9 @@ import pytest -from griptape.config import AnthropicStructureConfig +from griptape.config import AnthropicDriverConfig -class TestAnthropicStructureConfig: +class TestAnthropicDriverConfig: @pytest.fixture(autouse=True) def _mock_anthropic(self, mocker): mocker.patch("anthropic.Anthropic") @@ -11,12 +11,12 @@ def _mock_anthropic(self, mocker): @pytest.fixture() def config(self): - return AnthropicStructureConfig() + return AnthropicDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "AnthropicStructureConfig", - "prompt_driver": { + "type": "AnthropicDriverConfig", + "prompt": { "type": "AnthropicPromptDriver", "temperature": 0.1, "max_tokens": 1000, @@ -26,18 +26,18 @@ def test_to_dict(self, config): "top_k": 250, "use_native_tools": True, }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": { + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": { "type": "AnthropicImageQueryDriver", "model": "claude-3-5-sonnet-20240620", "max_tokens": 256, }, - "embedding_driver": { + "embedding": { "type": "VoyageAiEmbeddingDriver", "model": "voyage-large-2", "input_type": "document", }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "VoyageAiEmbeddingDriver", @@ -45,10 +45,10 @@ def test_to_dict(self, config): "input_type": "document", }, }, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert AnthropicStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert AnthropicDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_driver_config.py similarity index 70% rename from tests/unit/config/test_azure_openai_structure_config.py rename to tests/unit/config/test_azure_openai_driver_config.py index dcdc3a1dc..3c88b859d 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_driver_config.py @@ -1,16 +1,16 @@ import pytest -from griptape.config import AzureOpenAiStructureConfig +from griptape.config import AzureOpenAiDriverConfig -class TestAzureOpenAiStructureConfig: +class TestAzureOpenAiDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") @pytest.fixture() def config(self): - return AzureOpenAiStructureConfig( + return AzureOpenAiDriverConfig( azure_endpoint="http://localhost:8080", azure_ad_token="test-token", azure_ad_token_provider=lambda: "test-provider", @@ -18,9 +18,9 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { - "type": "AzureOpenAiStructureConfig", + "type": "AzureOpenAiDriverConfig", "azure_endpoint": "http://localhost:8080", - "prompt_driver": { + "prompt": { "type": "AzureOpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", @@ -36,8 +36,8 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, - "embedding_driver": { + "conversation_memory": None, + "embedding": { "base_url": None, "model": "text-embedding-3-small", "api_version": "2023-05-15", @@ -46,7 +46,7 @@ def test_to_dict(self, config): "organization": None, "type": "AzureOpenAiEmbeddingDriver", }, - "image_generation_driver": { + "image_generation": { "api_version": "2024-02-01", "base_url": None, "image_size": "512x512", @@ -59,7 +59,7 @@ def test_to_dict(self, config): "style": None, "type": "AzureOpenAiImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "base_url": None, "image_quality": "auto", "max_tokens": 256, @@ -70,7 +70,7 @@ def test_to_dict(self, config): "organization": None, "type": "AzureOpenAiImageQueryDriver", }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", @@ -82,19 +82,6 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } - - def test_from_dict(self, config: AzureOpenAiStructureConfig): - assert AzureOpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - - # override values in the dict config - # serialize and deserialize the config - new_config = config.merge_config( - { - "prompt_driver": {"azure_deployment": "new-test-gpt-4"}, - "embedding_driver": {"model": "new-text-embedding-3-small"}, - } - ).to_dict() - assert AzureOpenAiStructureConfig.from_dict(new_config).to_dict() == new_config diff --git a/tests/unit/config/test_cohere_structure_config.py b/tests/unit/config/test_cohere_driver_config.py similarity index 57% rename from tests/unit/config/test_cohere_structure_config.py rename to tests/unit/config/test_cohere_driver_config.py index 113a589ec..9e8407d84 100644 --- a/tests/unit/config/test_cohere_structure_config.py +++ b/tests/unit/config/test_cohere_driver_config.py @@ -1,22 +1,22 @@ import pytest -from griptape.config import CohereStructureConfig +from griptape.config import CohereDriverConfig -class TestCohereStructureConfig: +class TestCohereDriverConfig: @pytest.fixture() def config(self): - return CohereStructureConfig(api_key="api_key") + return CohereDriverConfig(api_key="api_key") def test_to_dict(self, config): assert config.to_dict() == { - "type": "CohereStructureConfig", - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, - "prompt_driver": { + "type": "CohereDriverConfig", + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + "prompt": { "type": "CoherePromptDriver", "temperature": 0.1, "max_tokens": None, @@ -25,12 +25,12 @@ def test_to_dict(self, config): "force_single_step": False, "use_native_tools": True, }, - "embedding_driver": { + "embedding": { "type": "CohereEmbeddingDriver", "model": "embed-english-v3.0", "input_type": "search_document", }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "CohereEmbeddingDriver", diff --git a/tests/unit/config/test_driver_config.py b/tests/unit/config/test_driver_config.py new file mode 100644 index 000000000..dd3fd1a47 --- /dev/null +++ b/tests/unit/config/test_driver_config.py @@ -0,0 +1,39 @@ +import pytest + +from griptape.config import DriverConfig + + +class TestDriverConfig: + @pytest.fixture() + def config(self): + return DriverConfig() + + def test_to_dict(self, config): + assert config.to_dict() == { + "type": "DriverConfig", + "prompt": { + "type": "DummyPromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "use_native_tools": False, + }, + "conversation_memory": None, + "embedding": {"type": "DummyEmbeddingDriver"}, + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "vector_store": { + "embedding_driver": {"type": "DummyEmbeddingDriver"}, + "type": "DummyVectorStoreDriver", + }, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, + } + + def test_from_dict(self, config): + assert DriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + + def test_dot_update(self, config): + config.prompt.max_tokens = 10 + + assert config.prompt.max_tokens == 10 diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_driver_config.py similarity index 62% rename from tests/unit/config/test_google_structure_config.py rename to tests/unit/config/test_google_driver_config.py index e193cc983..fb6cd23b5 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_driver_config.py @@ -1,21 +1,21 @@ import pytest -from griptape.config import GoogleStructureConfig +from griptape.config import GoogleDriverConfig -class TestGoogleStructureConfig: +class TestGoogleDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("google.generativeai.GenerativeModel") @pytest.fixture() def config(self): - return GoogleStructureConfig() + return GoogleDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "GoogleStructureConfig", - "prompt_driver": { + "type": "GoogleDriverConfig", + "prompt": { "type": "GooglePromptDriver", "temperature": 0.1, "max_tokens": None, @@ -26,15 +26,15 @@ def test_to_dict(self, config): "tool_choice": "auto", "use_native_tools": True, }, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "embedding_driver": { + "image_generation": {"type": "DummyImageGenerationDriver"}, + "image_query": {"type": "DummyImageQueryDriver"}, + "embedding": { "type": "GoogleEmbeddingDriver", "model": "models/embedding-001", "task_type": "retrieval_document", "title": None, }, - "vector_store_driver": { + "vector_store": { "type": "LocalVectorStoreDriver", "embedding_driver": { "type": "GoogleEmbeddingDriver", @@ -43,10 +43,10 @@ def test_to_dict(self, config): "title": None, }, }, - "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "conversation_memory": None, + "text_to_speech": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert GoogleStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert GoogleDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_driver_config.py similarity index 80% rename from tests/unit/config/test_openai_structure_config.py rename to tests/unit/config/test_openai_driver_config.py index 8969e0ad0..55156730c 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_driver_config.py @@ -1,21 +1,21 @@ import pytest -from griptape.config import OpenAiStructureConfig +from griptape.config import OpenAiDriverConfig -class TestOpenAiStructureConfig: +class TestOpenAiDriverConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.OpenAI") @pytest.fixture() def config(self): - return OpenAiStructureConfig() + return OpenAiDriverConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "OpenAiStructureConfig", - "prompt_driver": { + "type": "OpenAiDriverConfig", + "prompt": { "type": "OpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", @@ -28,14 +28,14 @@ def test_to_dict(self, config): "user": "", "use_native_tools": True, }, - "conversation_memory_driver": None, - "embedding_driver": { + "conversation_memory": None, + "embedding": { "base_url": None, "model": "text-embedding-3-small", "organization": None, "type": "OpenAiEmbeddingDriver", }, - "image_generation_driver": { + "image_generation": { "api_version": None, "base_url": None, "image_size": "512x512", @@ -46,7 +46,7 @@ def test_to_dict(self, config): "style": None, "type": "OpenAiImageGenerationDriver", }, - "image_query_driver": { + "image_query": { "api_version": None, "base_url": None, "image_quality": "auto", @@ -55,7 +55,7 @@ def test_to_dict(self, config): "organization": None, "type": "OpenAiImageQueryDriver", }, - "vector_store_driver": { + "vector_store": { "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", @@ -64,7 +64,7 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "text_to_speech_driver": { + "text_to_speech": { "type": "OpenAiTextToSpeechDriver", "api_version": None, "base_url": None, @@ -73,7 +73,7 @@ def test_to_dict(self, config): "organization": None, "voice": "alloy", }, - "audio_transcription_driver": { + "audio_transcription": { "type": "OpenAiAudioTranscriptionDriver", "api_version": None, "base_url": None, @@ -83,4 +83,4 @@ def test_to_dict(self, config): } def test_from_dict(self, config): - assert OpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert OpenAiDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py deleted file mode 100644 index 96a68628f..000000000 --- a/tests/unit/config/test_structure_config.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest - -from griptape.config import StructureConfig - - -class TestStructureConfig: - @pytest.fixture() - def config(self): - return StructureConfig() - - def test_to_dict(self, config): - assert config.to_dict() == { - "type": "StructureConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "use_native_tools": False, - }, - "conversation_memory_driver": None, - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "vector_store_driver": { - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "type": "DummyVectorStoreDriver", - }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, - } - - def test_from_dict(self, config): - assert StructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - - def test_unchanged_merge_config(self, config): - assert ( - config.merge_config( - { - "type": "StructureConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - }, - } - ).to_dict() - == config.to_dict() - ) - - def test_changed_merge_config(self, config): - config = config.merge_config( - {"prompt_driver": {"type": "DummyPromptDriver", "temperature": 0.1, "max_tokens": None, "stream": False}} - ) - - assert config.prompt_driver.temperature == 0.1 - - def test_dot_update(self, config): - config.prompt_driver.max_tokens = 10 - - assert config.prompt_driver.max_tokens == 10 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index e462ede90..8a37f6d28 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,6 +1,8 @@ import pytest +from griptape.config import config from griptape.events import event_bus +from tests.mocks.mock_driver_config import MockDriverConfig @pytest.fixture(autouse=True) @@ -10,3 +12,10 @@ def mock_event_bus(): yield event_bus event_bus.clear_event_listeners() + + +@pytest.fixture(autouse=True) +def mock_config(): + config.drivers = MockDriverConfig() + + return config diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 6fcab26e5..61ef3aa53 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -12,7 +12,7 @@ class TestBaseAudioTranscriptionDriver: def driver(self): return MockAudioTranscriptionDriver() - def test_run_publish_events(self, driver): + def test_run_publish_events(self, driver, mock_config): mock_handler = Mock() event_bus.add_event_listener(EventListener(handler=mock_handler)) diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index 8e700d0a5..f1a5df1be 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -6,7 +6,6 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.aws import mock_aws_credentials @@ -40,7 +39,6 @@ def test_store(self): session = boto3.Session(region_name=self.AWS_REGION) dynamodb = session.resource("dynamodb") table = dynamodb.Table(self.DYNAMODB_TABLE_NAME) - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=session, table_name=self.DYNAMODB_TABLE_NAME, @@ -49,7 +47,7 @@ def test_store(self): partition_key_value=self.PARTITION_KEY_VALUE, ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -65,7 +63,6 @@ def test_store_with_sort_key(self): session = boto3.Session(region_name=self.AWS_REGION) dynamodb = session.resource("dynamodb") table = dynamodb.Table(self.DYNAMODB_TABLE_NAME) - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=session, table_name=self.DYNAMODB_TABLE_NAME, @@ -76,7 +73,7 @@ def test_store_with_sort_key(self): sort_key_value="foo", ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -89,7 +86,6 @@ def test_store_with_sort_key(self): assert "Item" in response def test_load(self): - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=boto3.Session(region_name=self.AWS_REGION), table_name=self.DYNAMODB_TABLE_NAME, @@ -98,7 +94,7 @@ def test_load(self): partition_key_value=self.PARTITION_KEY_VALUE, ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -113,7 +109,6 @@ def test_load(self): assert new_memory.runs[0].output.value == "mock output" def test_load_with_sort_key(self): - prompt_driver = MockPromptDriver() memory_driver = AmazonDynamoDbConversationMemoryDriver( session=boto3.Session(region_name=self.AWS_REGION), table_name=self.DYNAMODB_TABLE_NAME, @@ -124,7 +119,7 @@ def test_load_with_sort_key(self): sort_key_value="foo", ) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index e1a383ab9..dff66d0fc 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -7,7 +7,6 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestLocalConversationMemoryDriver: @@ -22,10 +21,9 @@ def _run_before_and_after_tests(self): self.__delete_file(self.MEMORY_FILE_PATH) def test_store(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver, autoload=False) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -41,10 +39,9 @@ def test_store(self): assert True def test_load(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver, autoload=False, max_runs=5) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) @@ -60,10 +57,9 @@ def test_load(self): assert new_memory.max_runs == 5 def test_autoload(self): - prompt_driver = MockPromptDriver() memory_driver = LocalConversationMemoryDriver(file_path=self.MEMORY_FILE_PATH) memory = ConversationMemory(driver=memory_driver) - pipeline = Pipeline(prompt_driver=prompt_driver, conversation_memory=memory) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_task(PromptTask("test")) diff --git a/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py b/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py index 4f7ce50f0..758505b26 100644 --- a/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py +++ b/tests/unit/drivers/observability/test_open_telemetry_observability_driver.py @@ -8,7 +8,6 @@ from griptape.drivers import OpenTelemetryObservabilityDriver from griptape.observability.observability import Observability from griptape.structures.agent import Agent -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils.expected_spans import ExpectedSpan, ExpectedSpans @@ -170,7 +169,7 @@ def test_observability_agent(self, driver, mock_span_exporter): ) with Observability(observability_driver=driver): - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.run("Hi") assert mock_span_exporter.export.call_count == 1 diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b6b0c600..c30acdec4 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -10,17 +10,17 @@ class TestBasePromptDriver: - def test_run_via_pipeline_retries_success(self): - driver = MockPromptDriver(max_attempts=1) - pipeline = Pipeline(prompt_driver=driver) + def test_run_via_pipeline_retries_success(self, mock_config): + mock_config.drivers.prompt = MockPromptDriver(max_attempts=2) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) assert isinstance(pipeline.run().output_task.output, TextArtifact) - def test_run_via_pipeline_retries_failure(self): - driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) - pipeline = Pipeline(prompt_driver=driver) + def test_run_via_pipeline_retries_failure(self, mock_config): + mock_config.drivers.prompt = MockFailingPromptDriver(max_failures=2, max_attempts=1) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -28,8 +28,7 @@ def test_run_via_pipeline_retries_failure(self): def test_run_via_pipeline_publishes_events(self, mocker): mock_publish_event = mocker.patch.object(_EventBus, "publish_event") - driver = MockPromptDriver() - pipeline = Pipeline(prompt_driver=driver) + pipeline = Pipeline() pipeline.add_task(PromptTask("test")) pipeline.run() @@ -46,9 +45,9 @@ def test_run_with_stream(self): assert isinstance(result, Message) assert result.value == "mock output" - def test_run_with_tools(self): - driver = MockPromptDriver(max_attempts=1, use_native_tools=True) - pipeline = Pipeline(prompt_driver=driver) + def test_run_with_tools(self, mock_config): + mock_config.drivers.prompt = MockPromptDriver(max_attempts=1, use_native_tools=True) + pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) @@ -56,9 +55,9 @@ def test_run_with_tools(self): assert isinstance(output, TextArtifact) assert output.value == "mock output" - def test_run_with_tools_and_stream(self): - driver = MockPromptDriver(max_attempts=1, stream=True, use_native_tools=True) - pipeline = Pipeline(prompt_driver=driver) + def test_run_with_tools_and_stream(self, mock_config): + mock_config.driver = MockPromptDriver(max_attempts=1, stream=True, use_native_tools=True) + pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 316f7bf71..c2bb45208 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -9,7 +9,7 @@ class TestLocalStructureRunDriver: def test_run(self): pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent(prompt_driver=MockPromptDriver())) + driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent()) task = StructureRunTask(driver=driver) @@ -17,10 +17,11 @@ def test_run(self): assert task.run().to_text() == "mock output" - def test_run_with_env(self): + def test_run_with_env(self, mock_config): pipeline = Pipeline() - agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda _: os.environ["KEY"])) + mock_config.drivers.prompt = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index f69d8a0ba..d84fc7cdd 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -1,13 +1,12 @@ import pytest from griptape.engines import CsvExtractionEngine -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestCsvExtractionEngine: @pytest.fixture() def engine(self): - return CsvExtractionEngine(prompt_driver=MockPromptDriver()) + return CsvExtractionEngine() def test_extract(self, engine): result = engine.extract("foo", column_names=["test1"]) diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index 4d0aad139..430f67ef9 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -4,13 +4,12 @@ from griptape.common import Reference from griptape.engines.rag import RagContext from griptape.engines.rag.modules import FootnotePromptResponseRagModule -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestFootnotePromptResponseRagModule: @pytest.fixture() def module(self): - return FootnotePromptResponseRagModule(prompt_driver=MockPromptDriver()) + return FootnotePromptResponseRagModule() def test_run(self, module): assert module.run(RagContext(query="test")).value == "mock output" diff --git a/tests/unit/engines/rag/test_rag_engine.py b/tests/unit/engines/rag/test_rag_engine.py index 964a52650..71db4e01f 100644 --- a/tests/unit/engines/rag/test_rag_engine.py +++ b/tests/unit/engines/rag/test_rag_engine.py @@ -25,14 +25,12 @@ def engine(self): ) def test_module_name_uniqueness(self): - vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) - with pytest.raises(ValueError): RagEngine( retrieval_stage=RetrievalRagStage( retrieval_modules=[ - VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver), - VectorStoreRetrievalRagModule(name="test", vector_store_driver=vector_store_driver), + VectorStoreRetrievalRagModule(name="test"), + VectorStoreRetrievalRagModule(name="test"), ] ) ) @@ -40,8 +38,8 @@ def test_module_name_uniqueness(self): assert RagEngine( retrieval_stage=RetrievalRagStage( retrieval_modules=[ - VectorStoreRetrievalRagModule(name="test1", vector_store_driver=vector_store_driver), - VectorStoreRetrievalRagModule(name="test2", vector_store_driver=vector_store_driver), + VectorStoreRetrievalRagModule(name="test1"), + VectorStoreRetrievalRagModule(name="test2"), ] ) ) diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 4d9c65e03..138444ae3 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -12,7 +12,7 @@ class TestPromptSummaryEngine: @pytest.fixture() def engine(self): - return PromptSummaryEngine(prompt_driver=MockPromptDriver()) + return PromptSummaryEngine() def test_summarize_text(self, engine): assert engine.summarize_text("foobar") == "mock output" @@ -24,10 +24,10 @@ def test_summarize_artifacts(self, engine): def test_max_token_multiplier_invalid(self, engine): with pytest.raises(ValueError): - PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=0) + PromptSummaryEngine(max_token_multiplier=0) with pytest.raises(ValueError): - PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=10000) + PromptSummaryEngine(max_token_multiplier=10000) def test_chunked_summary(self, engine): def smaller_input(prompt_stack: PromptStack): diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 50763e0c3..4e21fa392 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -25,16 +25,17 @@ class TestEventListener: @pytest.fixture() - def pipeline(self): + def pipeline(self, mock_config): + mock_config.drivers.prompt = MockPromptDriver(stream=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) - pipeline = Pipeline(prompt_driver=MockPromptDriver(stream=True)) + pipeline = Pipeline() pipeline.add_task(task) task.add_subtask(ActionsSubtask("foo")) return pipeline - def test_untyped_listeners(self, pipeline): + def test_untyped_listeners(self, pipeline, mock_config): event_handler_1 = Mock() event_handler_2 = Mock() @@ -48,7 +49,7 @@ def test_untyped_listeners(self, pipeline): assert event_handler_1.call_count == 9 assert event_handler_2.call_count == 9 - def test_typed_listeners(self, pipeline): + def test_typed_listeners(self, pipeline, mock_config): start_prompt_event_handler = Mock() finish_prompt_event_handler = Mock() start_task_event_handler = Mock() diff --git a/tests/unit/events/test_finish_actions_subtask_event.py b/tests/unit/events/test_finish_actions_subtask_event.py index 5e2a0807a..5fc35755b 100644 --- a/tests/unit/events/test_finish_actions_subtask_event.py +++ b/tests/unit/events/test_finish_actions_subtask_event.py @@ -3,7 +3,6 @@ from griptape.events import FinishActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -17,7 +16,7 @@ def finish_subtask_event(self): "Answer: test output" ) task = ToolkitTask(tools=[MockTool()]) - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) subtask = ActionsSubtask(valid_input) task.add_subtask(subtask) diff --git a/tests/unit/events/test_finish_task_event.py b/tests/unit/events/test_finish_task_event.py index df1d6d42a..2568752bb 100644 --- a/tests/unit/events/test_finish_task_event.py +++ b/tests/unit/events/test_finish_task_event.py @@ -3,14 +3,13 @@ from griptape.events import FinishTaskEvent from griptape.structures import Agent from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestFinishTaskEvent: @pytest.fixture() def finish_task_event(self): task = PromptTask() - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) agent.run() diff --git a/tests/unit/events/test_start_actions_subtask_event.py b/tests/unit/events/test_start_actions_subtask_event.py index 8b628057c..b7236911f 100644 --- a/tests/unit/events/test_start_actions_subtask_event.py +++ b/tests/unit/events/test_start_actions_subtask_event.py @@ -3,7 +3,6 @@ from griptape.events import StartActionsSubtaskEvent from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -17,7 +16,7 @@ def start_subtask_event(self): "Answer: test output" ) task = ToolkitTask(tools=[MockTool()]) - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) subtask = ActionsSubtask(valid_input) task.add_subtask(subtask) diff --git a/tests/unit/events/test_start_task_event.py b/tests/unit/events/test_start_task_event.py index ea027f147..111d35934 100644 --- a/tests/unit/events/test_start_task_event.py +++ b/tests/unit/events/test_start_task_event.py @@ -3,14 +3,13 @@ from griptape.events import StartTaskEvent from griptape.structures import Agent from griptape.tasks import PromptTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStartTaskEvent: @pytest.fixture() def start_task_event(self): task = PromptTask() - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() agent.add_task(task) agent.run() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 2ffd7b8cb..f0e4b0af3 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -60,7 +60,7 @@ def test_from_json(self): def test_buffering(self): memory = ConversationMemory(max_runs=2) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask()) @@ -75,7 +75,7 @@ def test_buffering(self): assert pipeline.conversation_memory.runs[1].input.value == "run5" def test_add_to_prompt_stack_autopruing_disabled(self): - agent = Agent(prompt_driver=MockPromptDriver()) + agent = Agent() memory = ConversationMemory( autoprune=False, runs=[ @@ -94,9 +94,11 @@ def test_add_to_prompt_stack_autopruing_disabled(self): assert len(prompt_stack.messages) == 12 - def test_add_to_prompt_stack_autopruning_enabled(self): + def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) + + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ @@ -117,7 +119,8 @@ def test_add_to_prompt_stack_autopruning_enabled(self): assert len(prompt_stack.messages) == 3 # No memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ @@ -140,7 +143,8 @@ def test_add_to_prompt_stack_autopruning_enabled(self): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160))) + mock_config.drivers.prompt = MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160)) + agent = Agent() memory = ConversationMemory( autoprune=True, runs=[ diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index 4396c7b23..42246e349 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -5,14 +5,13 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestSummaryConversationMemory: def test_unsummarized_subtasks(self): - memory = SummaryConversationMemory(offset=1, prompt_driver=MockPromptDriver()) + memory = SummaryConversationMemory(offset=1) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) @@ -24,9 +23,9 @@ def test_unsummarized_subtasks(self): assert len(memory.unsummarized_runs()) == 1 def test_after_run(self): - memory = SummaryConversationMemory(offset=1, prompt_driver=MockPromptDriver()) + memory = SummaryConversationMemory(offset=1) - pipeline = Pipeline(conversation_memory=memory, prompt_driver=MockPromptDriver()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) @@ -85,7 +84,7 @@ def test_from_json(self): def test_config_prompt_driver(self): memory = SummaryConversationMemory() - pipeline = Pipeline(conversation_memory=memory, config=MockStructureConfig()) + pipeline = Pipeline(conversation_memory=memory) pipeline.add_tasks(PromptTask("test")) diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index e3d9034c4..33bfdc5ee 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -76,18 +76,6 @@ def test_with_no_task_memory_and_empty_tool_output_memory(self): assert agent.tools[0].input_memory[0] == agent.task_memory assert agent.tools[0].output_memory == {} - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - agent = Agent(tools=[MockTool()], embedding_driver=embedding_driver) - - storage = list(agent.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_without_default_task_memory(self): agent = Agent(task_memory=None, tools=[MockTool()]) @@ -233,7 +221,7 @@ def test_context(self): def test_task_memory_defaults(self): prompt_driver = MockPromptDriver() embedding_driver = MockEmbeddingDriver() - agent = Agent(prompt_driver=prompt_driver, embedding_driver=embedding_driver) + agent = Agent(prompt_driver=prompt_driver) storage = list(agent.task_memory.artifact_storages.values())[0] assert isinstance(storage, TextArtifactStorage) @@ -248,16 +236,6 @@ def test_task_memory_defaults(self): assert storage.csv_extraction_engine.prompt_driver == prompt_driver assert storage.json_extraction_engine.prompt_driver == prompt_driver - def test_deprecation(self): - with pytest.deprecated_call(): - Agent(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Agent(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Agent(stream=True) - def finished_tasks(self): task = PromptTask("test prompt") agent = Agent(prompt_driver=MockPromptDriver()) diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index 306fd7bd2..a7f7f40c1 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -4,14 +4,11 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Pipeline from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask from griptape.tokenizers import OpenAiTokenizer -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool -from tests.unit.structures.test_agent import MockEmbeddingDriver class TestPipeline: @@ -31,10 +28,8 @@ def fn(task): return CodeExecutionTask(run_fn=fn) def test_init(self): - driver = MockPromptDriver() - pipeline = Pipeline(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + pipeline = Pipeline(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) - assert pipeline.prompt_driver is driver assert pipeline.input_task is None assert pipeline.output_task is None assert pipeline.rulesets[0].name == "TestRuleset" @@ -103,20 +98,6 @@ def test_with_task_memory(self): assert pipeline.tasks[0].tools[0].output_memory is not None assert pipeline.tasks[0].tools[0].output_memory["test"][0] == pipeline.task_memory - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - pipeline = Pipeline(embedding_driver=embedding_driver) - - pipeline.add_task(ToolkitTask(tools=[MockTool()])) - - storage = list(pipeline.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_with_task_memory_and_empty_tool_output_memory(self): pipeline = Pipeline() @@ -139,7 +120,7 @@ def test_with_memory(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline + [first_task, second_task, third_task] @@ -174,7 +155,7 @@ def test_tasks_order(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + first_task pipeline + second_task @@ -189,7 +170,7 @@ def test_add_task(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + first_task pipeline + second_task @@ -208,7 +189,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] @@ -227,7 +208,7 @@ def test_insert_task_in_middle(self): second_task = PromptTask("test2", id="test2") third_task = PromptTask("test3", id="test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] pipeline.insert_task(first_task, third_task) @@ -251,7 +232,7 @@ def test_insert_task_at_end(self): second_task = PromptTask("test2", id="test2") third_task = PromptTask("test3", id="test3") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [first_task, second_task] pipeline.insert_task(second_task, third_task) @@ -271,7 +252,7 @@ def test_insert_task_at_end(self): assert [child.id for child in third_task.children] == [] def test_prompt_stack_without_memory(self): - pipeline = Pipeline(conversation_memory=None, prompt_driver=MockPromptDriver(), rules=[Rule("test")]) + pipeline = Pipeline(conversation_memory=None, rules=[Rule("test")]) task1 = PromptTask("test") task2 = PromptTask("test") @@ -292,7 +273,7 @@ def test_prompt_stack_without_memory(self): assert len(task2.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), rules=[Rule("test")]) + pipeline = Pipeline(rules=[Rule("test")]) task1 = PromptTask("test") task2 = PromptTask("test") @@ -321,7 +302,7 @@ def test_text_artifact_token_count(self): def test_run(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task assert task.state == BaseTask.State.PENDING @@ -333,7 +314,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [task] pipeline._execution_args = ("test1", "test2") @@ -348,7 +329,7 @@ def test_context(self): parent = PromptTask("parent") task = PromptTask("test") child = PromptTask("child") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + [parent, task, child] @@ -365,35 +346,23 @@ def test_context(self): assert context["parent"] == parent assert context["child"] == child - def test_deprecation(self): - with pytest.deprecated_call(): - Pipeline(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Pipeline(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Pipeline(stream=True) - def test_run_with_error_artifact(self, error_artifact_task, waiting_task): end_task = PromptTask("end") - pipeline = Pipeline(prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task]) + pipeline = Pipeline(tasks=[waiting_task, error_artifact_task, end_task]) pipeline.run() assert pipeline.output is None def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting_task): end_task = PromptTask("end") - pipeline = Pipeline( - prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False - ) + pipeline = Pipeline(tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False) pipeline.run() assert pipeline.output is not None def test_add_duplicate_task(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task pipeline + task @@ -402,7 +371,7 @@ def test_add_duplicate_task(self): def test_add_duplicate_task_directly(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline + task pipeline.tasks.append(task) diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 242de29c5..79c9868e1 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -4,12 +4,9 @@ from griptape.artifacts import ErrorArtifact, TextArtifact from griptape.memory.structure import ConversationMemory -from griptape.memory.task.storage import TextArtifactStorage from griptape.rules import Rule, Ruleset from griptape.structures import Workflow from griptape.tasks import BaseTask, CodeExecutionTask, PromptTask, ToolkitTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -30,10 +27,8 @@ def fn(task): return CodeExecutionTask(run_fn=fn) def test_init(self): - driver = MockPromptDriver() - workflow = Workflow(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])]) + workflow = Workflow(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) - assert workflow.prompt_driver is driver assert len(workflow.tasks) == 0 assert workflow.rulesets[0].name == "TestRuleset" assert workflow.rulesets[0].rules[0].value == "test" @@ -100,20 +95,6 @@ def test_with_task_memory(self): assert workflow.tasks[0].tools[0].output_memory is not None assert workflow.tasks[0].tools[0].output_memory["test"][0] == workflow.task_memory - def test_embedding_driver(self): - embedding_driver = MockEmbeddingDriver() - workflow = Workflow(embedding_driver=embedding_driver) - - workflow.add_task(ToolkitTask(tools=[MockTool()])) - - storage = list(workflow.task_memory.artifact_storages.values())[0] - assert isinstance(storage, TextArtifactStorage) - memory_embedding_driver = storage.rag_engine.retrieval_stage.retrieval_modules[ - 0 - ].vector_store_driver.embedding_driver - - assert memory_embedding_driver == embedding_driver - def test_with_task_memory_and_empty_tool_output_memory(self): workflow = Workflow() @@ -136,7 +117,7 @@ def test_with_memory(self): second_task = PromptTask("test2") third_task = PromptTask("test3") - workflow = Workflow(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + workflow = Workflow(conversation_memory=ConversationMemory()) workflow + [first_task, second_task, third_task] @@ -170,7 +151,7 @@ def test_add_task(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + first_task workflow.add_task(second_task) @@ -189,7 +170,7 @@ def test_add_tasks(self): first_task = PromptTask("test1") second_task = PromptTask("test2") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + [first_task, second_task] @@ -206,7 +187,7 @@ def test_add_tasks(self): def test_run(self): task1 = PromptTask("test") task2 = PromptTask("test") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + [task1, task2] assert task1.state == BaseTask.State.PENDING @@ -219,7 +200,7 @@ def test_run(self): def test_run_with_args(self): task = PromptTask("{{ args[0] }}-{{ args[1] }}") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task workflow._execution_args = ("test1", "test2") @@ -241,7 +222,7 @@ def test_run_with_args(self): ], ) def test_run_raises_on_missing_parent_or_child_id(self, tasks): - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + workflow = Workflow(tasks=tasks) with pytest.raises(ValueError) as e: workflow.run() @@ -250,7 +231,6 @@ def test_run_raises_on_missing_parent_or_child_id(self, tasks): def test_run_topology_1_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task1"]), @@ -265,7 +245,6 @@ def test_run_topology_1_declarative_parents(self): def test_run_topology_1_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task2", "task3"]), PromptTask("test2", id="task2", child_ids=["task4"]), @@ -280,7 +259,6 @@ def test_run_topology_1_declarative_children(self): def test_run_topology_1_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task3"]), PromptTask("test2", id="task2", parent_ids=["task1"], child_ids=["task4"]), @@ -301,7 +279,7 @@ def test_run_topology_1_imperative_parents(self): task2.add_parent(task1) task3.add_parent("task1") task4.add_parents([task2, "task3"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() @@ -315,14 +293,14 @@ def test_run_topology_1_imperative_children(self): task1.add_children([task2, task3]) task2.add_child(task4) task3.add_child(task4) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() self._validate_topology_1(workflow) def test_run_topology_1_imperative_parents_structure_init(self): - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() task1 = PromptTask("test1", id="task1") task2 = PromptTask("test2", id="task2", structure=workflow) task3 = PromptTask("test3", id="task3", structure=workflow) @@ -336,7 +314,7 @@ def test_run_topology_1_imperative_parents_structure_init(self): self._validate_topology_1(workflow) def test_run_topology_1_imperative_children_structure_init(self): - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() task1 = PromptTask("test1", id="task1", structure=workflow) task2 = PromptTask("test2", id="task2", structure=workflow) task3 = PromptTask("test3", id="task3", structure=workflow) @@ -356,7 +334,7 @@ def test_run_topology_1_imperative_mixed(self): task4 = PromptTask("test4", id="task4") task1.add_children([task2, task3]) task4.add_parents([task2, task3]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + workflow = Workflow(tasks=[task1, task2, task3, task4]) workflow.run() @@ -367,7 +345,7 @@ def test_run_topology_1_imperative_insert(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task1 splits into task2 and task3 # task2 and task3 converge into task4 @@ -384,7 +362,7 @@ def test_run_topology_1_missing_parent(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task1 never added to workflow workflow + task4 @@ -396,7 +374,7 @@ def test_run_topology_1_id_equality(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() # task4 never added to workflow workflow + task1 @@ -410,7 +388,7 @@ def test_run_topology_1_object_equality(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -419,7 +397,6 @@ def test_run_topology_1_object_equality(self): def test_run_topology_2_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("testa", id="taska"), PromptTask("testb", id="taskb", parent_ids=["taska"]), @@ -435,7 +412,6 @@ def test_run_topology_2_declarative_parents(self): def test_run_topology_2_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("testa", id="taska", child_ids=["taskb", "taskc", "taskd", "taske"]), PromptTask("testb", id="taskb", child_ids=["taskd"]), @@ -459,7 +435,7 @@ def test_run_topology_2_imperative_parents(self): taskc.add_parent("taska") taskd.add_parents([taska, taskb, taskc]) taske.add_parents(["taska", taskd, "taskc"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -475,7 +451,7 @@ def test_run_topology_2_imperative_children(self): taskb.add_child(taskd) taskc.add_children([taskd, taske]) taskd.add_child(taske) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -491,7 +467,7 @@ def test_run_topology_2_imperative_mixed(self): taskb.add_child(taskd) taskd.add_parent(taskc) taske.add_parents(["taska", taskd, "taskc"]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + workflow = Workflow(tasks=[taska, taskb, taskc, taskd, taske]) workflow.run() @@ -503,7 +479,7 @@ def test_run_topology_2_imperative_insert(self): taskc = PromptTask("testc", id="taskc") taskd = PromptTask("testd", id="taskd") taske = PromptTask("teste", id="taske") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow.add_task(taska) workflow.add_task(taske) taske.add_parent(taska) @@ -517,7 +493,6 @@ def test_run_topology_2_imperative_insert(self): def test_run_topology_3_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task4"]), @@ -532,7 +507,6 @@ def test_run_topology_3_declarative_parents(self): def test_run_topology_3_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1", child_ids=["task4"]), PromptTask("test2", id="task2", child_ids=["task3"]), @@ -547,7 +521,6 @@ def test_run_topology_3_declarative_children(self): def test_run_topology_3_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task4"], child_ids=["task3"]), @@ -565,7 +538,7 @@ def test_run_topology_3_imperative_insert(self): task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") task4 = PromptTask("test4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task2 @@ -580,7 +553,6 @@ def test_run_topology_3_imperative_insert(self): def test_run_topology_4_declarative_parents(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info"), PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"]), @@ -600,7 +572,6 @@ def test_run_topology_4_declarative_parents(self): def test_run_topology_4_declarative_children(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info", child_ids=["movie_info_1", "movie_info_2", "movie_info_3"]), PromptTask(id="movie_info_1", child_ids=["compare_movies"]), @@ -620,7 +591,6 @@ def test_run_topology_4_declarative_children(self): def test_run_topology_4_declarative_mixed(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask(id="collect_movie_info"), PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), @@ -650,7 +620,7 @@ def test_run_topology_4_imperative_insert(self): publish_website = PromptTask(id="publish_website") movie_info_3 = PromptTask(id="movie_info_3") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow.add_tasks(collect_movie_info, summarize_to_slack) workflow.insert_tasks(collect_movie_info, [movie_info_1, movie_info_2, movie_info_3], summarize_to_slack) workflow.insert_tasks([movie_info_1, movie_info_2, movie_info_3], compare_movies, summarize_to_slack) @@ -672,7 +642,7 @@ def test_run_topology_4_imperative_insert(self): ], ) def test_run_raises_on_cycle(self, tasks): - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + workflow = Workflow(tasks=tasks) with pytest.raises(ValueError) as e: workflow.run() @@ -684,7 +654,7 @@ def test_input_task(self): task2 = PromptTask("prompt2") task3 = PromptTask("prompt3") task4 = PromptTask("prompt4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -697,7 +667,7 @@ def test_output_task(self): task2 = PromptTask("prompt2") task3 = PromptTask("prompt3") task4 = PromptTask("prompt4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -709,7 +679,7 @@ def test_output_task(self): task1.add_children([task2, task3]) # task4 is the final task, but its defined at index 0 - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task4, task1, task2, task3]) + workflow = Workflow(tasks=[task4, task1, task2, task3]) # output_task topologically should be task4 assert task4 == workflow.output_task @@ -719,7 +689,7 @@ def test_to_graph(self): task2 = PromptTask("prompt2", id="task2") task3 = PromptTask("prompt3", id="task3") task4 = PromptTask("prompt4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -736,7 +706,7 @@ def test_order_tasks(self): task2 = PromptTask("prompt2", id="task2") task3 = PromptTask("prompt3", id="task3") task4 = PromptTask("prompt4", id="task4") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + task1 workflow + task4 @@ -753,7 +723,7 @@ def test_context(self): parent = PromptTask("parent") task = PromptTask("test") child = PromptTask("child") - workflow = Workflow(prompt_driver=MockPromptDriver()) + workflow = Workflow() workflow + parent workflow + task @@ -776,20 +746,10 @@ def test_context(self): assert context["parents"] == {parent.id: parent} assert context["children"] == {child.id: child} - def test_deprecation(self): - with pytest.deprecated_call(): - Workflow(prompt_driver=MockPromptDriver()) - - with pytest.deprecated_call(): - Workflow(embedding_driver=MockEmbeddingDriver()) - - with pytest.deprecated_call(): - Workflow(stream=True) - def test_run_with_error_artifact(self, error_artifact_task, waiting_task): end_task = PromptTask("end") end_task.add_parents([error_artifact_task, waiting_task]) - workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task]) + workflow = Workflow(tasks=[waiting_task, error_artifact_task, end_task]) workflow.run() assert workflow.output is None @@ -797,9 +757,7 @@ def test_run_with_error_artifact(self, error_artifact_task, waiting_task): def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting_task): end_task = PromptTask("end") end_task.add_parents([error_artifact_task, waiting_task]) - workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False - ) + workflow = Workflow(tasks=[waiting_task, error_artifact_task, end_task], fail_fast=False) workflow.run() assert workflow.output is not None diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index 734e111cf..33405ad10 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -6,8 +6,6 @@ from griptape.engines import AudioTranscriptionEngine from griptape.structures import Agent, Pipeline from griptape.tasks import AudioTranscriptionTask, BaseTask -from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestAudioTranscriptionTask: @@ -34,7 +32,7 @@ def callable_input(task: BaseTask) -> AudioArtifact: def test_config_audio_transcription_engine(self, audio_artifact): task = AudioTranscriptionTask(audio_artifact) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.audio_transcription_engine, AudioTranscriptionEngine) @@ -42,7 +40,7 @@ def test_run(self, audio_artifact, audio_transcription_engine): audio_transcription_engine.run.return_value = TextArtifact("mock transcription") task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) assert pipeline.run().output.to_text() == "mock transcription" diff --git a/tests/unit/tasks/test_base_multi_text_input_task.py b/tests/unit/tasks/test_base_multi_text_input_task.py index 3d8d67a55..8eaa832ae 100644 --- a/tests/unit/tasks/test_base_multi_text_input_task.py +++ b/tests/unit/tasks/test_base_multi_text_input_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import TextArtifact from griptape.structures import Pipeline from tests.mocks.mock_multi_text_input_task import MockMultiTextInputTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestBaseMultiTextInputTask: @@ -42,7 +41,7 @@ def test_full_context(self): parent = MockMultiTextInputTask(("parent1", "parent2")) subtask = MockMultiTextInputTask(("test1", "test2"), context={"foo": "bar"}) child = MockMultiTextInputTask(("child2", "child2")) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_tasks(parent, subtask, child) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index aa402bb48..1b45b4e98 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -7,8 +7,6 @@ from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_task import MockTask from tests.mocks.mock_tool.tool import MockTool @@ -18,10 +16,9 @@ class TestBaseTask: def task(self): event_bus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( - prompt_driver=MockPromptDriver(), - embedding_driver=MockEmbeddingDriver(), tools=[MockTool()], ) + event_bus.add_event_listeners([EventListener(handler=Mock())]) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) diff --git a/tests/unit/tasks/test_base_text_input_task.py b/tests/unit/tasks/test_base_text_input_task.py index 86dc98805..ff6afe42b 100644 --- a/tests/unit/tasks/test_base_text_input_task.py +++ b/tests/unit/tasks/test_base_text_input_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import TextArtifact from griptape.rules import Rule, Ruleset from griptape.structures import Pipeline -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_text_input_task import MockTextInputTask @@ -31,7 +30,7 @@ def test_full_context(self): parent = MockTextInputTask("parent") subtask = MockTextInputTask("test", context={"foo": "bar"}) child = MockTextInputTask("child") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_tasks(parent, subtask, child) diff --git a/tests/unit/tasks/test_code_execution_task.py b/tests/unit/tasks/test_code_execution_task.py index 3178e29db..e2c492fad 100644 --- a/tests/unit/tasks/test_code_execution_task.py +++ b/tests/unit/tasks/test_code_execution_task.py @@ -1,7 +1,6 @@ from griptape.artifacts import BaseArtifact, ErrorArtifact, TextArtifact from griptape.structures import Pipeline from griptape.tasks import CodeExecutionTask -from tests.mocks.mock_prompt_driver import MockPromptDriver def hello_world(task: CodeExecutionTask) -> BaseArtifact: @@ -27,7 +26,7 @@ def test_hello_world_fn(self): # Using a Pipeline # Overriding the input because we are implementing the task not the Pipeline def test_noop_fn(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() task = CodeExecutionTask("No Op", run_fn=non_outputting) pipeline.add_task(task) temp = task.run() diff --git a/tests/unit/tasks/test_csv_extraction_task.py b/tests/unit/tasks/test_csv_extraction_task.py index 7d37c3897..ec8f70b23 100644 --- a/tests/unit/tasks/test_csv_extraction_task.py +++ b/tests/unit/tasks/test_csv_extraction_task.py @@ -4,7 +4,6 @@ from griptape.structures import Agent from griptape.tasks import CsvExtractionTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestCsvExtractionTask: @@ -13,7 +12,7 @@ def task(self): return CsvExtractionTask(args={"column_names": ["test1"]}) def test_run(self, task): - agent = Agent(config=MockStructureConfig()) + agent = Agent() agent.add_task(task) @@ -23,11 +22,7 @@ def test_run(self, task): assert result.value[0].value == {"test1": "mock output"} def test_config_extraction_engine(self, task): - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.extraction_engine, CsvExtractionEngine) assert isinstance(task.extraction_engine.prompt_driver, MockPromptDriver) - - def test_missing_extraction_engine(self, task): - with pytest.raises(ValueError): - task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index afa73a506..76a4c3bd2 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -3,15 +3,12 @@ from griptape.engines import CsvExtractionEngine from griptape.structures import Agent from griptape.tasks import ExtractionTask -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestExtractionTask: @pytest.fixture() def task(self): - return ExtractionTask( - extraction_engine=CsvExtractionEngine(prompt_driver=MockPromptDriver()), args={"column_names": ["test1"]} - ) + return ExtractionTask(extraction_engine=CsvExtractionEngine(), args={"column_names": ["test1"]}) def test_run(self, task): agent = Agent() diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py index 447faa01c..01c116772 100644 --- a/tests/unit/tasks/test_image_query_task.py +++ b/tests/unit/tasks/test_image_query_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, ImageQueryTask from tests.mocks.mock_image_query_driver import MockImageQueryDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestImageQueryTask: @@ -61,17 +60,11 @@ def test_list_input(self, text_artifact: TextArtifact, image_artifact: ImageArti def test_config_image_generation_engine(self, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_query_engine, ImageQueryEngine) assert isinstance(task.image_query_engine.image_query_driver, MockImageQueryDriver) - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = ImageQueryTask((text_artifact, [image_artifact, image_artifact])) - - with pytest.raises(ValueError, match="Image Query Engine"): - task.image_query_engine # noqa: B018 - def test_run(self, image_query_engine, text_artifact, image_artifact): task = ImageQueryTask((text_artifact, [image_artifact, image_artifact]), image_query_engine=image_query_engine) task.run() diff --git a/tests/unit/tasks/test_inpainting_image_generation_task.py b/tests/unit/tasks/test_inpainting_image_generation_task.py index 61c437bb7..5c4507d49 100644 --- a/tests/unit/tasks/test_inpainting_image_generation_task.py +++ b/tests/unit/tasks/test_inpainting_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, InpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestInpaintingImageGenerationTask: @@ -51,13 +50,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, InpaintingImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = InpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_json_extraction_task.py b/tests/unit/tasks/test_json_extraction_task.py index ba7d1ce30..8f9278c3c 100644 --- a/tests/unit/tasks/test_json_extraction_task.py +++ b/tests/unit/tasks/test_json_extraction_task.py @@ -5,7 +5,6 @@ from griptape.structures import Agent from griptape.tasks import JsonExtractionTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestJsonExtractionTask: @@ -13,11 +12,9 @@ class TestJsonExtractionTask: def task(self): return JsonExtractionTask("foo", args={"template_schema": Schema({"foo": "bar"}).json_schema("TemplateSchema")}) - def test_run(self, task): - mock_config = MockStructureConfig() - assert isinstance(mock_config.prompt_driver, MockPromptDriver) - mock_config.prompt_driver.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' - agent = Agent(config=mock_config) + def test_run(self, task, mock_config): + mock_config.drivers.prompt.mock_output = '[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]' + agent = Agent() agent.add_task(task) @@ -28,11 +25,7 @@ def test_run(self, task): assert result.value[1].value == '{"test_key_2": "test_value_2"}' def test_config_extraction_engine(self, task): - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.extraction_engine, JsonExtractionEngine) assert isinstance(task.extraction_engine.prompt_driver, MockPromptDriver) - - def test_missing_extraction_engine(self, task): - with pytest.raises(ValueError): - task.extraction_engine # noqa: B018 diff --git a/tests/unit/tasks/test_outpainting_image_generation_task.py b/tests/unit/tasks/test_outpainting_image_generation_task.py index 593451120..ba5e52a82 100644 --- a/tests/unit/tasks/test_outpainting_image_generation_task.py +++ b/tests/unit/tasks/test_outpainting_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, OutpaintingImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestOutpaintingImageGenerationTask: @@ -51,13 +50,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, OutpaintingImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_image_generation_engine(self, text_artifact, image_artifact): - task = OutpaintingImageGenerationTask((text_artifact, image_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_image_generation_task.py b/tests/unit/tasks/test_prompt_image_generation_task.py index 1c4b639fb..3ad0302f2 100644 --- a/tests/unit/tasks/test_prompt_image_generation_task.py +++ b/tests/unit/tasks/test_prompt_image_generation_task.py @@ -1,13 +1,10 @@ from unittest.mock import Mock -import pytest - from griptape.artifacts import TextArtifact from griptape.engines import PromptImageGenerationEngine from griptape.structures import Agent from griptape.tasks import BaseTask, PromptImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestPromptImageGenerationTask: @@ -28,13 +25,7 @@ def callable_input(task: BaseTask) -> TextArtifact: def test_config_image_generation_engine_engine(self): task = PromptImageGenerationTask("foo bar") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, PromptImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_summary_engine(self): - task = PromptImageGenerationTask("foo bar") - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 083ea6da5..cfe853226 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,5 +1,3 @@ -import pytest - from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.list_artifact import ListArtifact from griptape.artifacts.text_artifact import TextArtifact @@ -9,13 +7,12 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestPromptTask: def test_run(self): task = PromptTask("test") - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) @@ -30,16 +27,10 @@ def test_to_text(self): def test_config_prompt_driver(self): task = PromptTask("test") - Pipeline(config=MockStructureConfig()).add_task(task) + Pipeline().add_task(task) assert isinstance(task.prompt_driver, MockPromptDriver) - def test_missing_prompt_driver(self): - task = PromptTask("test") - - with pytest.raises(ValueError): - task.prompt_driver # noqa: B018 - def test_input(self): # Str task = PromptTask("test") diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 1053ade9e..d18d75d75 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -5,9 +5,11 @@ class TestStructureRunTask: - def test_run(self): - agent = Agent(prompt_driver=MockPromptDriver(mock_output="agent mock output")) - pipeline = Pipeline(prompt_driver=MockPromptDriver(mock_output="pipeline mock output")) + def test_run(self, mock_config): + mock_config.drivers.prompt = MockPromptDriver(mock_output="agent mock output") + agent = Agent() + mock_config.drivers.prompt = MockPromptDriver(mock_output="pipeline mock output") + pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) task = StructureRunTask(driver=driver) diff --git a/tests/unit/tasks/test_text_summary_task.py b/tests/unit/tasks/test_text_summary_task.py index bb08f9d31..f83075f2a 100644 --- a/tests/unit/tasks/test_text_summary_task.py +++ b/tests/unit/tasks/test_text_summary_task.py @@ -1,15 +1,12 @@ -import pytest - from griptape.engines import PromptSummaryEngine from griptape.structures import Agent from griptape.tasks import TextSummaryTask from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestTextSummaryTask: def test_run(self): - task = TextSummaryTask("test", summary_engine=PromptSummaryEngine(prompt_driver=MockPromptDriver())) + task = TextSummaryTask("test", summary_engine=PromptSummaryEngine()) agent = Agent() agent.add_task(task) @@ -26,13 +23,7 @@ def test_context_propagation(self): def test_config_summary_engine(self): task = TextSummaryTask("test") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.summary_engine, PromptSummaryEngine) assert isinstance(task.summary_engine.prompt_driver, MockPromptDriver) - - def test_missing_summary_engine(self): - task = TextSummaryTask("test") - - with pytest.raises(ValueError): - task.summary_engine # noqa: B018 diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index bf1f19d5a..44348fef0 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -4,8 +4,6 @@ from griptape.engines import TextToSpeechEngine from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask -from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestTextToSpeechTask: @@ -26,7 +24,7 @@ def callable_input(task: BaseTask) -> TextArtifact: def test_config_text_to_speech_engine(self): task = TextToSpeechTask("foo bar") - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.text_to_speech_engine, TextToSpeechEngine) @@ -41,7 +39,7 @@ def test_run(self): text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3") task = TextToSpeechTask("some text", text_to_speech_engine=text_to_speech_engine) - pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline = Pipeline() pipeline.add_task(task) assert isinstance(pipeline.run().output, AudioArtifact) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index dfc679919..18521632e 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -5,7 +5,6 @@ from griptape.artifacts import TextArtifact from griptape.structures import Agent from griptape.tasks import ActionsSubtask, ToolTask -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -166,13 +165,12 @@ class TestToolTask: } @pytest.fixture() - def agent(self): + def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - return Agent( - prompt_driver=MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}"), - embedding_driver=MockEmbeddingDriver(), - ) + mock_config.drivers.prompt = MockPromptDriver(mock_output=f"```python foo bar\n{json.dumps(output_dict)}") + + return Agent() def test_run_without_memory(self, agent): task = ToolTask(tool=MockTool()) diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index cd5dd21f8..c1b91b1ed 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -2,7 +2,6 @@ from griptape.common import ToolAction from griptape.structures import Agent from griptape.tasks import ActionsSubtask, PromptTask, ToolkitTask -from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -170,11 +169,12 @@ def test_init(self): except ValueError: assert True - def test_run(self): + def test_run(self, mock_config): output = """Answer: done""" + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) @@ -184,11 +184,12 @@ def test_run(self): assert len(task.subtasks) == 1 assert result.output_task.output.to_text() == "done" - def test_run_max_subtasks(self): + def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) @@ -197,11 +198,12 @@ def test_run_max_subtasks(self): assert len(task.subtasks) == 3 assert isinstance(task.output, ErrorArtifact) - def test_run_invalid_react_prompt(self): + def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" + mock_config.drivers.prompt.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) - agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) + agent = Agent() agent.add_task(task) diff --git a/tests/unit/tasks/test_variation_image_generation_task.py b/tests/unit/tasks/test_variation_image_generation_task.py index a910fb8e0..f6afbf03e 100644 --- a/tests/unit/tasks/test_variation_image_generation_task.py +++ b/tests/unit/tasks/test_variation_image_generation_task.py @@ -8,7 +8,6 @@ from griptape.structures import Agent from griptape.tasks import BaseTask, VariationImageGenerationTask from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver -from tests.mocks.mock_structure_config import MockStructureConfig class TestVariationImageGenerationTask: @@ -48,13 +47,7 @@ def test_bad_input(self, image_artifact): def test_config_image_generation_engine(self, text_artifact, image_artifact): task = VariationImageGenerationTask((text_artifact, image_artifact)) - Agent(config=MockStructureConfig()).add_task(task) + Agent().add_task(task) assert isinstance(task.image_generation_engine, VariationImageGenerationEngine) assert isinstance(task.image_generation_engine.image_generation_driver, MockImageGenerationDriver) - - def test_missing_summary_engine(self, text_artifact, image_artifact): - task = VariationImageGenerationTask((text_artifact, image_artifact)) - - with pytest.raises(ValueError): - task.image_generation_engine # noqa: B018 diff --git a/tests/unit/tools/test_structure_run_client.py b/tests/unit/tools/test_structure_run_client.py index d498b7c56..ee76d4da1 100644 --- a/tests/unit/tools/test_structure_run_client.py +++ b/tests/unit/tools/test_structure_run_client.py @@ -3,14 +3,12 @@ from griptape.drivers.structure_run.local_structure_run_driver import LocalStructureRunDriver from griptape.structures import Agent from griptape.tools import StructureRunClient -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureRunClient: @pytest.fixture() def client(self): - driver = MockPromptDriver() - agent = Agent(prompt_driver=driver) + agent = Agent() return StructureRunClient( description="foo bar", driver=LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/utils/test_chat.py b/tests/unit/utils/test_chat.py index 42ecc59c3..5f97d1baf 100644 --- a/tests/unit/utils/test_chat.py +++ b/tests/unit/utils/test_chat.py @@ -1,14 +1,13 @@ from griptape.memory.structure import ConversationMemory from griptape.structures import Agent from griptape.utils import Chat -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestConversation: def test_init(self): import logging - agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + agent = Agent(conversation_memory=ConversationMemory()) chat = Chat( agent, diff --git a/tests/unit/utils/test_conversation.py b/tests/unit/utils/test_conversation.py index 28ee72409..a07d15cdb 100644 --- a/tests/unit/utils/test_conversation.py +++ b/tests/unit/utils/test_conversation.py @@ -2,12 +2,11 @@ from griptape.structures import Pipeline from griptape.tasks import PromptTask from griptape.utils import Conversation -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestConversation: def test_lines(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) @@ -22,7 +21,7 @@ def test_lines(self): assert lines[3] == "A: mock output" def test_prompt_stack_conversation_memory(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) @@ -36,8 +35,7 @@ def test_prompt_stack_conversation_memory(self): def test_prompt_stack_summary_conversation_memory(self): pipeline = Pipeline( - prompt_driver=MockPromptDriver(), - conversation_memory=SummaryConversationMemory(summary="foobar", prompt_driver=MockPromptDriver()), + conversation_memory=SummaryConversationMemory(summary="foobar"), ) pipeline.add_tasks(PromptTask("question 1")) @@ -52,7 +50,7 @@ def test_prompt_stack_summary_conversation_memory(self): assert lines[2] == "assistant: mock output" def test___str__(self): - pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) + pipeline = Pipeline(conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py index a9c122126..00df6958d 100644 --- a/tests/unit/utils/test_file_utils.py +++ b/tests/unit/utils/test_file_utils.py @@ -3,7 +3,6 @@ from griptape import utils from griptape.loaders import TextLoader -from tests.mocks.mock_embedding_driver import MockEmbeddingDriver MAX_TOKENS = 50 @@ -32,7 +31,7 @@ def test_load_files(self): def test_load_file_with_loader(self): dirname = os.path.dirname(__file__) file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) - artifacts = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()).load(file) + artifacts = TextLoader(max_tokens=MAX_TOKENS).load(file) assert len(artifacts) == 39 assert isinstance(artifacts, list) @@ -43,7 +42,7 @@ def test_load_files_with_loader(self): sources = ["resources/foobar-many.txt"] sources = [os.path.join(dirname, "../../", source) for source in sources] files = utils.load_files(sources) - loader = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) + loader = TextLoader(max_tokens=MAX_TOKENS) collection = loader.load_collection(list(files.values())) test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index da6695139..caddbb1a3 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -2,18 +2,17 @@ import pytest -from griptape.structures import Agent +from griptape.structures import Agent, Pipeline from griptape.utils import Stream -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - return Agent(prompt_driver=MockPromptDriver(stream=request.param, max_attempts=0)) + return Agent(stream=request.param) def test_init(self, agent): - if agent.prompt_driver.stream: + if agent.stream: chat_stream = Stream(agent) assert chat_stream.structure == agent @@ -28,3 +27,9 @@ def test_init(self, agent): else: with pytest.raises(ValueError): Stream(agent) + + def test_validate_structure_invalid(self): + pipeline = Pipeline(tasks=[]) + + with pytest.raises(ValueError): + Stream(pipeline) diff --git a/tests/unit/utils/test_structure_visualizer.py b/tests/unit/utils/test_structure_visualizer.py index f6e621b91..8a055cb21 100644 --- a/tests/unit/utils/test_structure_visualizer.py +++ b/tests/unit/utils/test_structure_visualizer.py @@ -1,12 +1,11 @@ from griptape.structures import Agent, Pipeline, Workflow from griptape.tasks import PromptTask from griptape.utils import StructureVisualizer -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestStructureVisualizer: def test_agent(self): - agent = Agent(prompt_driver=MockPromptDriver(), tasks=[PromptTask("test1", id="task1")]) + agent = Agent(tasks=[PromptTask("test1", id="task1")]) visualizer = StructureVisualizer(agent) result = visualizer.to_url() @@ -15,7 +14,6 @@ def test_agent(self): def test_pipeline(self): pipeline = Pipeline( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2"), @@ -34,7 +32,6 @@ def test_pipeline(self): def test_workflow(self): workflow = Workflow( - prompt_driver=MockPromptDriver(), tasks=[ PromptTask("test1", id="task1"), PromptTask("test2", id="task2", parent_ids=["task1"]), diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index 154f63ac4..8e9d45fb7 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -17,9 +17,9 @@ def text_tool_artifact_storage(): rag_engine=rag_engine(MockPromptDriver(), vector_store_driver), vector_store_driver=vector_store_driver, retrieval_rag_module_name="VectorStoreRetrievalRagModule", - summary_engine=PromptSummaryEngine(prompt_driver=MockPromptDriver()), - csv_extraction_engine=CsvExtractionEngine(prompt_driver=MockPromptDriver()), - json_extraction_engine=JsonExtractionEngine(prompt_driver=MockPromptDriver()), + summary_engine=PromptSummaryEngine(), + csv_extraction_engine=CsvExtractionEngine(), + json_extraction_engine=JsonExtractionEngine(), ) diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 5b908065b..d87fc095e 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -25,9 +25,7 @@ def get_enabled_prompt_drivers(prompt_drivers_options) -> list[BasePromptDriver]: return [ - prompt_driver_option.prompt_driver - for prompt_driver_option in prompt_drivers_options - if prompt_driver_option.enabled + prompt_driver_option.prompt for prompt_driver_option in prompt_drivers_options if prompt_driver_option.enabled ] @@ -228,6 +226,15 @@ def prompt_driver_id_fn(cls, prompt_driver) -> str: return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}" def verify_structure_output(self, structure) -> dict: + from griptape.config import config + + config.drivers.prompt = AzureOpenAiChatPromptDriver( + api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + model="gpt-4o", + azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], + response_format="json_object", + ) output_schema = Schema( { Literal("correct", description="Whether the output was correct or not."): bool, @@ -265,13 +272,6 @@ def verify_structure_output(self, structure) -> dict: ], ), ], - prompt_driver=AzureOpenAiChatPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="gpt-4o", - azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - response_format="json_object", - ), tasks=[ PromptTask( "\nTasks: {{ task_names }}" diff --git a/tests/utils/test_reference_utils.py b/tests/utils/test_reference_utils.py index c3491f5d0..47da18713 100644 --- a/tests/utils/test_reference_utils.py +++ b/tests/utils/test_reference_utils.py @@ -1,12 +1,11 @@ from griptape.artifacts import TextArtifact from griptape.common import Reference from griptape.engines.rag.modules import PromptResponseRagModule -from tests.mocks.mock_prompt_driver import MockPromptDriver class TestReferenceUtils: def test_references_from_artifacts(self): - module = PromptResponseRagModule(prompt_driver=MockPromptDriver()) + module = PromptResponseRagModule() reference1 = Reference(title="foo") reference2 = Reference(title="bar") artifacts = [