diff --git a/CHANGELOG.md b/CHANGELOG.md index a45f76053..21f72fe50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `structure` to `BaseTask`. - Method `try_find_task` to `Structure`. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. -- Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. -- Global config, `griptape.config.config`, for setting global configuration defaults. +- Global event bus, `griptape.events.EventBus`, for publishing and subscribing to events. +- Global object, `griptape.configs.Defaults`, for setting default values throughout the framework. - Unique name generation for all `RagEngine` modules. - `ExtractionTool` for having the LLM extract structured data from text. - `PromptSummaryTool` for having the LLM summarize text. @@ -27,14 +27,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `griptape.utils.decorators.lazy_property` for creating lazy properties. ### Changed -- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. +- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `EventBus`. - **BREAKING**: Removed `EventPublisherMixin`. -- **BREAKING**: Removed `Pipeline.prompt_driver` and `Workflow.prompt_driver`. `Agent.prompt_driver` has not been removed. -- **BREAKING**: Removed `Pipeline.stream` and `Workflow.stream`. `Agent.stream` has not been removed. -- **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.config.config.drivers.embedding` instead. -- **BREAKING**: Removed `Structure.custom_logger` and `Structure.logger_level`, set these via `griptape.config.config.logger` instead. +- **BREAKING**: Removed `Pipeline.prompt_driver` and `Workflow.prompt_driver`. Set this via `griptape.configs.Defaults.drivers.prompt_driver` instead. `Agent.prompt_driver` has not been removed. +- **BREAKING**: Removed `Pipeline.stream` and `Workflow.stream`. Set this via `griptape.configs.Defaults.drivers.prompt_driver.stream` instead. `Agent.stream` has not been removed. +- **BREAKING**: Removed `Structure.embedding_driver`, set this via `griptape.configs.Defaults.drivers.embedding_driver` instead. +- **BREAKING**: Removed `Structure.custom_logger` and `Structure.logger_level`, set these via `logging.getLogger(griptape.configs.Defaults.logger_name)` instead. - **BREAKING**: Removed `BaseStructureConfig.merge_config`. -- **BREAKING**: Renamed `StructureConfig` to `DriverConfig`, moved to `griptape.config.drivers` and renamed fields accordingly. +- **BREAKING**: Renamed `StructureConfig` to `DriversConfig`, moved to `griptape.configs.drivers` and renamed fields accordingly. - **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly. - **BREAKING**: Removed before and after response modules from `ResponseRagStage`. - **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. @@ -53,7 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `JsonExtractionTask`, and `CsvExtractionTask` use `ExtractionTask` instead. - **BREAKING**: Removed `TaskMemoryClient`, use `QueryClient`, `ExtractionTool`, or `PromptSummaryTool` instead. - **BREAKING**: `BaseTask.add_parent/child` now take a `BaseTask` instead of `str | BaseTask`. -- Engines that previously required Drivers now pull from `griptape.config.config.drivers` by default. +- Engines that previously required Drivers now pull from `griptape.configs.Defaults.drivers_config` by default. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. - `BaseTask.add_parent/child` now returns `self`, allowing for chaining. - `Chat` now sets the `griptape` logger level to `logging.ERROR`, suppressing all logs except for errors. diff --git a/docs/examples/src/multiple_agent_shared_memory_1.py b/docs/examples/src/multiple_agent_shared_memory_1.py index b6089c190..e09f29ab7 100644 --- a/docs/examples/src/multiple_agent_shared_memory_1.py +++ b/docs/examples/src/multiple_agent_shared_memory_1.py @@ -1,7 +1,7 @@ import os -from griptape.config import config -from griptape.config.drivers import AzureOpenAiDriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import AzureOpenAiDriversConfig from griptape.drivers import AzureMongoDbVectorStoreDriver, AzureOpenAiEmbeddingDriver from griptape.structures import Agent from griptape.tools import QueryTool, WebScraperTool @@ -34,7 +34,7 @@ vector_path=MONGODB_VECTOR_PATH, ) -config.driver_config = AzureOpenAiDriverConfig( +Defaults.drivers_config = AzureOpenAiDriversConfig( azure_endpoint=AZURE_OPENAI_ENDPOINT_1, vector_store_driver=mongo_driver, embedding_driver=embedding_driver, diff --git a/docs/examples/src/talk_to_a_video_1.py b/docs/examples/src/talk_to_a_video_1.py index 2748902a2..d23c906b6 100644 --- a/docs/examples/src/talk_to_a_video_1.py +++ b/docs/examples/src/talk_to_a_video_1.py @@ -3,11 +3,11 @@ import google.generativeai as genai from griptape.artifacts import GenericArtifact, TextArtifact -from griptape.config import config -from griptape.config.drivers import GoogleDriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import GoogleDriversConfig from griptape.structures import Agent -config.driver_config = GoogleDriverConfig() +Defaults.drivers_config = GoogleDriversConfig() video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4") while video_file.state.name == "PROCESSING": diff --git a/docs/griptape-framework/drivers/src/embedding_drivers_10.py b/docs/griptape-framework/drivers/src/embedding_drivers_10.py index 2705dcfad..605b6e67a 100644 --- a/docs/griptape-framework/drivers/src/embedding_drivers_10.py +++ b/docs/griptape-framework/drivers/src/embedding_drivers_10.py @@ -1,5 +1,5 @@ -from griptape.config import config -from griptape.config.drivers import DriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import DriversConfig from griptape.drivers import ( OpenAiChatPromptDriver, VoyageAiEmbeddingDriver, @@ -7,12 +7,12 @@ from griptape.structures import Agent from griptape.tools import PromptSummaryTool, WebScraperTool -config.driver_config = DriverConfig( +Defaults.drivers_config = DriversConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), embedding_driver=VoyageAiEmbeddingDriver(), ) -config.driver_config = DriverConfig( +Defaults.drivers_config = DriversConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), embedding_driver=VoyageAiEmbeddingDriver(), ) diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_1.py b/docs/griptape-framework/drivers/src/event_listener_drivers_1.py index 01b02cf76..66b9372c3 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_1.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_1.py @@ -1,11 +1,11 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver -from griptape.events import EventListener, event_bus +from griptape.events import EventBus, EventListener from griptape.rules import Rule from griptape.structures import Agent -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_3.py b/docs/griptape-framework/drivers/src/event_listener_drivers_3.py index 3a2ba8560..0bb248362 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_3.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_3.py @@ -1,11 +1,11 @@ import os from griptape.drivers import AmazonSqsEventListenerDriver -from griptape.events import EventListener, event_bus +from griptape.events import EventBus, EventListener from griptape.rules import Rule from griptape.structures import Agent -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener( driver=AmazonSqsEventListenerDriver( diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_4.py b/docs/griptape-framework/drivers/src/event_listener_drivers_4.py index ff59be794..6d03d2ce3 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_4.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_4.py @@ -1,14 +1,14 @@ import os -from griptape.config import config -from griptape.config.drivers import DriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import DriversConfig from griptape.drivers import AwsIotCoreEventListenerDriver, OpenAiChatPromptDriver -from griptape.events import EventListener, FinishStructureRunEvent, event_bus +from griptape.events import EventBus, EventListener, FinishStructureRunEvent from griptape.rules import Rule from griptape.structures import Agent -config.driver_config = DriverConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7)) -event_bus.add_event_listeners( +Defaults.drivers_config = DriversConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7)) +EventBus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_5.py b/docs/griptape-framework/drivers/src/event_listener_drivers_5.py index b72ca19d4..27186e229 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_5.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_5.py @@ -1,8 +1,8 @@ from griptape.drivers import GriptapeCloudEventListenerDriver -from griptape.events import EventListener, FinishStructureRunEvent, event_bus +from griptape.events import EventBus, EventListener, FinishStructureRunEvent from griptape.structures import Agent -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_6.py b/docs/griptape-framework/drivers/src/event_listener_drivers_6.py index 0c594bf12..c60cc6984 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_6.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_6.py @@ -1,10 +1,10 @@ import os from griptape.drivers import WebhookEventListenerDriver -from griptape.events import EventListener, FinishStructureRunEvent, event_bus +from griptape.events import EventBus, EventListener, FinishStructureRunEvent from griptape.structures import Agent -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/drivers/src/event_listener_drivers_7.py b/docs/griptape-framework/drivers/src/event_listener_drivers_7.py index 0d4f6abf5..c010cb8f9 100644 --- a/docs/griptape-framework/drivers/src/event_listener_drivers_7.py +++ b/docs/griptape-framework/drivers/src/event_listener_drivers_7.py @@ -1,10 +1,10 @@ import os from griptape.drivers import PusherEventListenerDriver -from griptape.events import EventListener, FinishStructureRunEvent, event_bus +from griptape.events import EventBus, EventListener, FinishStructureRunEvent from griptape.structures import Agent -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener( event_types=[FinishStructureRunEvent], diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index a94746f99..3c4181aee 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -5,7 +5,7 @@ search: ## Overview -You can configure the global [event_bus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. +You can configure the global [EventBus](../../reference/griptape/events/event_bus.md) with [EventListener](../../reference/griptape/events/event_listener.md)s to listen for various framework events. See [Event Listener Drivers](../drivers/event-listener-drivers.md) for examples on forwarding events to external services. ## Specific Event Types diff --git a/docs/griptape-framework/misc/src/events_1.py b/docs/griptape-framework/misc/src/events_1.py index a69d63be0..993567cc6 100644 --- a/docs/griptape-framework/misc/src/events_1.py +++ b/docs/griptape-framework/misc/src/events_1.py @@ -1,5 +1,6 @@ from griptape.events import ( BaseEvent, + EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -7,7 +8,6 @@ StartActionsSubtaskEvent, StartPromptEvent, StartTaskEvent, - event_bus, ) from griptape.structures import Agent @@ -16,7 +16,7 @@ def handler(event: BaseEvent) -> None: print(event.__class__) -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener( handler, diff --git a/docs/griptape-framework/misc/src/events_2.py b/docs/griptape-framework/misc/src/events_2.py index be92bfb37..7c3a967fe 100644 --- a/docs/griptape-framework/misc/src/events_2.py +++ b/docs/griptape-framework/misc/src/events_2.py @@ -1,4 +1,4 @@ -from griptape.events import BaseEvent, EventListener, event_bus +from griptape.events import BaseEvent, EventBus, EventListener from griptape.structures import Agent @@ -10,7 +10,7 @@ def handler2(event: BaseEvent) -> None: print("Handler 2", event.__class__) -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener(handler1), EventListener(handler2), diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index 721cf3511..7adac812f 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -1,12 +1,12 @@ from typing import cast from griptape.drivers import OpenAiChatPromptDriver -from griptape.events import CompletionChunkEvent, EventListener, event_bus +from griptape.events import CompletionChunkEvent, EventBus, EventListener from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import PromptSummaryTool, WebScraperTool -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener( lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True), diff --git a/docs/griptape-framework/misc/src/events_4.py b/docs/griptape-framework/misc/src/events_4.py index 27bb3c5a8..eba11b07a 100644 --- a/docs/griptape-framework/misc/src/events_4.py +++ b/docs/griptape-framework/misc/src/events_4.py @@ -1,12 +1,12 @@ import logging -from griptape.config import config +from griptape.configs import Defaults from griptape.structures import Agent from griptape.tools import PromptSummaryTool, WebScraperTool from griptape.utils import Stream # Hide Griptape's usual output -logging.getLogger(config.logging_config.logger_name).setLevel(logging.ERROR) +logging.getLogger(Defaults.logging_config.logger_name).setLevel(logging.ERROR) agent = Agent( input="Based on https://griptape.ai, tell me what griptape is.", diff --git a/docs/griptape-framework/misc/src/events_5.py b/docs/griptape-framework/misc/src/events_5.py index 14b3f531e..65bdcec4d 100644 --- a/docs/griptape-framework/misc/src/events_5.py +++ b/docs/griptape-framework/misc/src/events_5.py @@ -1,5 +1,5 @@ from griptape import utils -from griptape.events import BaseEvent, EventListener, FinishPromptEvent, event_bus +from griptape.events import BaseEvent, EventBus, EventListener, FinishPromptEvent from griptape.structures import Agent token_counter = utils.TokenCounter() @@ -10,7 +10,7 @@ def count_tokens(e: BaseEvent) -> None: token_counter.add_tokens(e.output_token_count) -event_bus.add_event_listeners( +EventBus.add_event_listeners( [ EventListener( count_tokens, diff --git a/docs/griptape-framework/misc/src/events_6.py b/docs/griptape-framework/misc/src/events_6.py index aae9c2078..25934442a 100644 --- a/docs/griptape-framework/misc/src/events_6.py +++ b/docs/griptape-framework/misc/src/events_6.py @@ -1,7 +1,7 @@ -from griptape.events import BaseEvent, EventListener, StartPromptEvent, event_bus +from griptape.events import BaseEvent, EventBus, EventListener, StartPromptEvent from griptape.structures import Agent -event_bus.add_event_listeners([EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])]) +EventBus.add_event_listeners([EventListener(handler=lambda e: print(e), event_types=[StartPromptEvent])]) def handler(event: BaseEvent) -> None: diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md deleted file mode 100644 index 67721ebb1..000000000 --- a/docs/griptape-framework/structures/config.md +++ /dev/null @@ -1,89 +0,0 @@ ---- -search: - boost: 2 ---- - -## Overview - -Griptape exposes global configuration options to easily customize different parts of the framework. - -### Driver Configs - -The [DriverConfig](../../reference/griptape/config/drivers/driver_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. - -Griptape provides predefined [DriverConfig](../../reference/griptape/config/drivers/driver_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. - -#### OpenAI - -The [OpenAI Driver config](../../reference/griptape/config/drivers/openai_driver_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. - -```python ---8<-- "docs/griptape-framework/structures/src/config_1.py" -``` - -#### Azure OpenAI - -The [Azure OpenAI Driver config](../../reference/griptape/config/drivers/azure_openai_driver_config.md) provides default Drivers for Azure's OpenAI APIs. - -```python ---8<-- "docs/griptape-framework/structures/src/config_2.py" -``` - -#### Amazon Bedrock -The [Amazon Bedrock Driver config](../../reference/griptape/config/drivers/amazon_bedrock_driver_config.md) provides default Drivers for Amazon Bedrock's APIs. - -```python ---8<-- "docs/griptape-framework/structures/src/config_3.py" -``` - -#### Google -The [Google Driver config](../../reference/griptape/config/drivers/google_driver_config.md) provides default Drivers for Google's Gemini APIs. - -```python ---8<-- "docs/griptape-framework/structures/src/config_4.py" -``` - -#### Anthropic - -The [Anthropic Driver config](../../reference/griptape/config/drivers/anthropic_driver_config.md) provides default Drivers for Anthropic's APIs. - -!!! info - Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. - The `AnthropicDriverConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). - To override the default embedding driver, see: [Override Default Structure Embedding Driver](../drivers/embedding-drivers.md#override-default-structure-embedding-driver). - -```python ---8<-- "docs/griptape-framework/structures/src/config_5.py" -``` - -#### Cohere - -The [Cohere Driver config](../../reference/griptape/config/drivers/cohere_driver_config.md) provides default Drivers for Cohere's APIs. - -```python ---8<-- "docs/griptape-framework/structures/src/config_6.py" -``` - -#### Custom - -You can create your own [DriverConfig](../../reference/griptape/config/drivers/driver_config.md) by overriding relevant Drivers. -The [DriverConfig](../../reference/griptape/config/drivers/driver_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. -This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations. - -```python ---8<-- "docs/griptape-framework/structures/src/config_7.py" -``` - -### Logging Config - -Griptape provides a predefined [LoggingConfig](../../reference/griptape/config/logging/logging_config.md)'s for easily customizing the logging events that the framework emits. In order to customize the logger, the logger can be fetched by using the `config.logging.logger_name`. - -```python ---8<-- "docs/griptape-framework/structures/src/config_logging.py" -``` - -### Loading/Saving Configs - -```python ---8<-- "docs/griptape-framework/structures/src/config_8.py" -``` diff --git a/docs/griptape-framework/structures/configs.md b/docs/griptape-framework/structures/configs.md new file mode 100644 index 000000000..2a9b5c62d --- /dev/null +++ b/docs/griptape-framework/structures/configs.md @@ -0,0 +1,89 @@ +--- +search: + boost: 2 +--- + +## Overview + +Griptape exposes global configuration options to easily customize different parts of the framework. + +### Drivers Configs + +The [DriversConfig](../../reference/griptape/configs/drivers/drivers_config.md) class allows for the customization of Structures within Griptape, enabling specific settings such as Drivers to be defined for Tasks. + +Griptape provides predefined [DriversConfig](../../reference/griptape/configs/drivers/drivers_config.md)'s for widely used services that provide APIs for most Driver types Griptape offers. + +#### OpenAI + +The [OpenAI Driver config](../../reference/griptape/configs/drivers/openai_drivers_config.md) provides default Drivers for OpenAI's APIs. This is the default config for all Structures. + +```python +--8<-- "docs/griptape-framework/structures/src/drivers_config_1.py" +``` + +#### Azure OpenAI + +The [Azure OpenAI Driver config](../../reference/griptape/configs/drivers/azure_openai_drivers_config.md) provides default Drivers for Azure's OpenAI APIs. + +```python +--8<-- "docs/griptape-framework/structures/src/drivers_config_2.py" +``` + +#### Amazon Bedrock +The [Amazon Bedrock Driver config](../../reference/griptape/configs/drivers/amazon_bedrock_drivers_config.md) provides default Drivers for Amazon Bedrock's APIs. + +```python +--8<-- "docs/griptape-framework/structures/src/drivers_config_3.py" +``` + +#### Google +The [Google Driver config](../../reference/griptape/configs/drivers/google_drivers_config.md) provides default Drivers for Google's Gemini APIs. + +```python +--8<-- "docs/griptape-framework/structures/src/drivers_config_4.py" +``` + +#### Anthropic + +The [Anthropic Driver config](../../reference/griptape/configs/drivers/anthropic_drivers_config.md) provides default Drivers for Anthropic's APIs. + +!!! info + Anthropic does not provide an embeddings API which means you will need to use another service for embeddings. + The `AnthropicDriversConfig` defaults to using `VoyageAiEmbeddingDriver` which integrates with [VoyageAI](https://www.voyageai.com/), the service used in Anthropic's [embeddings documentation](https://docs.anthropic.com/claude/docs/embeddings). + To override the default embedding driver, see: [Override Default Structure Embedding Driver](../drivers/embedding-drivers.md#override-default-structure-embedding-driver). + +```python +--8<-- "docs/griptape-framework/structures/src/drivers_config_5.py" +``` + +#### Cohere + +The [Cohere Driver config](../../reference/griptape/configs/drivers/cohere_drivers_config.md) provides default Drivers for Cohere's APIs. + +```python +--8<-- "docs/griptape-framework/structures/src/drivers_config_6.py" +``` + +#### Custom + +You can create your own [DriversConfig](../../reference/griptape/configs/drivers/drivers_config.md) by overriding relevant Drivers. +The [DriversConfig](../../reference/griptape/configs/drivers/drivers_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. +This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations. + +```python +--8<-- "docs/griptape-framework/structures/src/drivers_config_7.py" +``` + +### Logging Config + +Griptape provides a predefined [LoggingConfig](../../reference/griptape/configs/logging/logging_config.md)'s for easily customizing the logging events that the framework emits. In order to customize the logger, the logger can be fetched by using the `Defaults.logging.logger_name`. + +```python +--8<-- "docs/griptape-framework/structures/src/logging_config.py" +``` + +### Loading/Saving Configs + +```python +--8<-- "docs/griptape-framework/structures/src/drivers_config_8.py" +``` diff --git a/docs/griptape-framework/structures/src/config_1.py b/docs/griptape-framework/structures/src/config_1.py deleted file mode 100644 index e038130c2..000000000 --- a/docs/griptape-framework/structures/src/config_1.py +++ /dev/null @@ -1,7 +0,0 @@ -from griptape.config import config -from griptape.config.drivers import OpenAiDriverConfig -from griptape.structures import Agent - -config.driver_config = OpenAiDriverConfig() - -agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_4.py b/docs/griptape-framework/structures/src/config_4.py deleted file mode 100644 index e97422388..000000000 --- a/docs/griptape-framework/structures/src/config_4.py +++ /dev/null @@ -1,7 +0,0 @@ -from griptape.config import config -from griptape.config.drivers import GoogleDriverConfig -from griptape.structures import Agent - -config.driver_config = GoogleDriverConfig() - -agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_5.py b/docs/griptape-framework/structures/src/config_5.py deleted file mode 100644 index 519b770df..000000000 --- a/docs/griptape-framework/structures/src/config_5.py +++ /dev/null @@ -1,7 +0,0 @@ -from griptape.config import config -from griptape.config.drivers import AnthropicDriverConfig -from griptape.structures import Agent - -config.driver_config = AnthropicDriverConfig() - -agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_6.py b/docs/griptape-framework/structures/src/config_6.py deleted file mode 100644 index c53d8c1b0..000000000 --- a/docs/griptape-framework/structures/src/config_6.py +++ /dev/null @@ -1,9 +0,0 @@ -import os - -from griptape.config import config -from griptape.config.drivers import CohereDriverConfig -from griptape.structures import Agent - -config.driver_config = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"]) - -agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_logging.py b/docs/griptape-framework/structures/src/config_logging.py deleted file mode 100644 index 4dceb6edd..000000000 --- a/docs/griptape-framework/structures/src/config_logging.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging - -from griptape.config import config -from griptape.config.drivers import OpenAiDriverConfig -from griptape.config.logging import TruncateLoggingFilter -from griptape.structures import Agent - -config.driver_config = OpenAiDriverConfig() - -logger = logging.getLogger(config.logging_config.logger_name) -logger.setLevel(logging.ERROR) -logger.addFilter(TruncateLoggingFilter(max_log_length=100)) - -agent = Agent() diff --git a/docs/griptape-framework/structures/src/drivers_config_1.py b/docs/griptape-framework/structures/src/drivers_config_1.py new file mode 100644 index 000000000..c156f8594 --- /dev/null +++ b/docs/griptape-framework/structures/src/drivers_config_1.py @@ -0,0 +1,7 @@ +from griptape.configs import Defaults +from griptape.configs.drivers import OpenAiDriversConfig +from griptape.structures import Agent + +Defaults.drivers_config = OpenAiDriversConfig() + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_2.py b/docs/griptape-framework/structures/src/drivers_config_2.py similarity index 53% rename from docs/griptape-framework/structures/src/config_2.py rename to docs/griptape-framework/structures/src/drivers_config_2.py index a187e8c06..b115a22f4 100644 --- a/docs/griptape-framework/structures/src/config_2.py +++ b/docs/griptape-framework/structures/src/drivers_config_2.py @@ -1,10 +1,10 @@ import os -from griptape.config import config -from griptape.config.drivers import AzureOpenAiDriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import AzureOpenAiDriversConfig from griptape.structures import Agent -config.driver_config = AzureOpenAiDriverConfig( +Defaults.drivers_config = AzureOpenAiDriversConfig( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"] ) diff --git a/docs/griptape-framework/structures/src/config_3.py b/docs/griptape-framework/structures/src/drivers_config_3.py similarity index 65% rename from docs/griptape-framework/structures/src/config_3.py rename to docs/griptape-framework/structures/src/drivers_config_3.py index 4d08912f9..0af0423de 100644 --- a/docs/griptape-framework/structures/src/config_3.py +++ b/docs/griptape-framework/structures/src/drivers_config_3.py @@ -2,11 +2,11 @@ import boto3 -from griptape.config import config -from griptape.config.drivers import AmazonBedrockDriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import AmazonBedrockDriversConfig from griptape.structures import Agent -config.driver_config = AmazonBedrockDriverConfig( +Defaults.drivers_config = AmazonBedrockDriversConfig( session=boto3.Session( region_name=os.environ["AWS_DEFAULT_REGION"], aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], diff --git a/docs/griptape-framework/structures/src/drivers_config_4.py b/docs/griptape-framework/structures/src/drivers_config_4.py new file mode 100644 index 000000000..f9cfb6d16 --- /dev/null +++ b/docs/griptape-framework/structures/src/drivers_config_4.py @@ -0,0 +1,7 @@ +from griptape.configs import Defaults +from griptape.configs.drivers import GoogleDriversConfig +from griptape.structures import Agent + +Defaults.drivers_config = GoogleDriversConfig() + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/drivers_config_5.py b/docs/griptape-framework/structures/src/drivers_config_5.py new file mode 100644 index 000000000..fb2aa8eee --- /dev/null +++ b/docs/griptape-framework/structures/src/drivers_config_5.py @@ -0,0 +1,7 @@ +from griptape.configs import Defaults +from griptape.configs.drivers import AnthropicDriversConfig +from griptape.structures import Agent + +Defaults.drivers_config = AnthropicDriversConfig() + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/drivers_config_6.py b/docs/griptape-framework/structures/src/drivers_config_6.py new file mode 100644 index 000000000..eaa8e3d71 --- /dev/null +++ b/docs/griptape-framework/structures/src/drivers_config_6.py @@ -0,0 +1,9 @@ +import os + +from griptape.configs import Defaults +from griptape.configs.drivers import CohereDriversConfig +from griptape.structures import Agent + +Defaults.drivers_config = CohereDriversConfig(api_key=os.environ["COHERE_API_KEY"]) + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/config_7.py b/docs/griptape-framework/structures/src/drivers_config_7.py similarity index 66% rename from docs/griptape-framework/structures/src/config_7.py rename to docs/griptape-framework/structures/src/drivers_config_7.py index 3f63d428e..3b1d396ce 100644 --- a/docs/griptape-framework/structures/src/config_7.py +++ b/docs/griptape-framework/structures/src/drivers_config_7.py @@ -1,11 +1,11 @@ import os -from griptape.config import config -from griptape.config.drivers import DriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import DriversConfig from griptape.drivers import AnthropicPromptDriver from griptape.structures import Agent -config.driver_config = DriverConfig( +Defaults.drivers_config = DriversConfig( prompt_driver=AnthropicPromptDriver( model="claude-3-sonnet-20240229", api_key=os.environ["ANTHROPIC_API_KEY"], diff --git a/docs/griptape-framework/structures/src/config_8.py b/docs/griptape-framework/structures/src/drivers_config_8.py similarity index 52% rename from docs/griptape-framework/structures/src/config_8.py rename to docs/griptape-framework/structures/src/drivers_config_8.py index 999911b25..f34a6d0b1 100644 --- a/docs/griptape-framework/structures/src/config_8.py +++ b/docs/griptape-framework/structures/src/drivers_config_8.py @@ -1,8 +1,8 @@ -from griptape.config import config -from griptape.config.drivers import AmazonBedrockDriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import AmazonBedrockDriversConfig from griptape.structures import Agent -custom_config = AmazonBedrockDriverConfig() +custom_config = AmazonBedrockDriversConfig() dict_config = custom_config.to_dict() # Use OpenAi for embeddings dict_config["embedding_driver"] = { @@ -11,8 +11,8 @@ "organization": None, "type": "OpenAiEmbeddingDriver", } -custom_config = AmazonBedrockDriverConfig.from_dict(dict_config) +custom_config = AmazonBedrockDriversConfig.from_dict(dict_config) -config.driver_config = custom_config +Defaults.drivers_config = custom_config agent = Agent() diff --git a/docs/griptape-framework/structures/src/logging_config.py b/docs/griptape-framework/structures/src/logging_config.py new file mode 100644 index 000000000..b220e2478 --- /dev/null +++ b/docs/griptape-framework/structures/src/logging_config.py @@ -0,0 +1,14 @@ +import logging + +from griptape.configs import Defaults +from griptape.configs.drivers import OpenAiDriversConfig +from griptape.configs.logging import TruncateLoggingFilter +from griptape.structures import Agent + +Defaults.drivers_config = OpenAiDriversConfig() + +logger = logging.getLogger(Defaults.logging_config.logger_name) +logger.setLevel(logging.ERROR) +logger.addFilter(TruncateLoggingFilter(max_log_length=100)) + +agent = Agent() diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 1ee4538d7..371b3c821 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -1,6 +1,6 @@ from griptape.artifacts import TextArtifact -from griptape.config import config -from griptape.config.drivers import OpenAiDriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import OpenAiDriversConfig from griptape.drivers import ( LocalVectorStoreDriver, OpenAiChatPromptDriver, @@ -11,11 +11,11 @@ from griptape.structures import Agent from griptape.tools import FileManagerTool, QueryTool, WebScraperTool -config.driver_config = OpenAiDriverConfig( +Defaults.drivers_config = OpenAiDriversConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ) -config.driver_config = OpenAiDriverConfig( +Defaults.drivers_config = OpenAiDriversConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), ) diff --git a/docs/griptape-tools/official-tools/src/rest_api_tool_1.py b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py index 1e82b2dd9..3f6b3b663 100644 --- a/docs/griptape-tools/official-tools/src/rest_api_tool_1.py +++ b/docs/griptape-tools/official-tools/src/rest_api_tool_1.py @@ -1,14 +1,14 @@ from json import dumps -from griptape.config import config -from griptape.config.drivers import DriverConfig +from griptape.configs import Defaults +from griptape.configs.drivers import DriversConfig from griptape.drivers import OpenAiChatPromptDriver from griptape.memory.structure import ConversationMemory from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import RestApiTool -config.driver_config = DriverConfig( +Defaults.drivers_config = DriversConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1), ) diff --git a/griptape/config/base_config.py b/griptape/config/base_config.py deleted file mode 100644 index 12d951765..000000000 --- a/griptape/config/base_config.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import TYPE_CHECKING - -from attrs import define, field - -from griptape.mixins.serializable_mixin import SerializableMixin - -if TYPE_CHECKING: - from .drivers.base_driver_config import BaseDriverConfig - from .logging.logging_config import LoggingConfig - - -@define(kw_only=True) -class BaseConfig(SerializableMixin, ABC): - logging_config: LoggingConfig = field() - driver_config: BaseDriverConfig = field() diff --git a/griptape/config/config.py b/griptape/config/config.py deleted file mode 100644 index 86bbfc8b7..000000000 --- a/griptape/config/config.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from attrs import Factory, define, field - -from .base_config import BaseConfig -from .drivers.openai_driver_config import OpenAiDriverConfig -from .logging.logging_config import LoggingConfig - -if TYPE_CHECKING: - from .drivers.base_driver_config import BaseDriverConfig - - -@define(kw_only=True) -class _Config(BaseConfig): - logging_config: LoggingConfig = field(default=Factory(lambda: LoggingConfig())) - driver_config: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig())) - - -config = _Config() diff --git a/griptape/config/drivers/__init__.py b/griptape/config/drivers/__init__.py deleted file mode 100644 index 9d5f2f510..000000000 --- a/griptape/config/drivers/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from .base_driver_config import BaseDriverConfig -from .driver_config import DriverConfig - -from .openai_driver_config import OpenAiDriverConfig -from .azure_openai_driver_config import AzureOpenAiDriverConfig -from .amazon_bedrock_driver_config import AmazonBedrockDriverConfig -from .anthropic_driver_config import AnthropicDriverConfig -from .google_driver_config import GoogleDriverConfig -from .cohere_driver_config import CohereDriverConfig - -__all__ = [ - "BaseDriverConfig", - "DriverConfig", - "OpenAiDriverConfig", - "AzureOpenAiDriverConfig", - "AmazonBedrockDriverConfig", - "AnthropicDriverConfig", - "GoogleDriverConfig", - "CohereDriverConfig", -] diff --git a/griptape/config/__init__.py b/griptape/configs/__init__.py similarity index 56% rename from griptape/config/__init__.py rename to griptape/configs/__init__.py index 043d152ba..bd12c7836 100644 --- a/griptape/config/__init__.py +++ b/griptape/configs/__init__.py @@ -1,8 +1,8 @@ from .base_config import BaseConfig -from .config import config +from .defaults_config import Defaults __all__ = [ "BaseConfig", - "config", + "Defaults", ] diff --git a/griptape/configs/base_config.py b/griptape/configs/base_config.py new file mode 100644 index 000000000..09d230016 --- /dev/null +++ b/griptape/configs/base_config.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from abc import ABC + +from attrs import define + +from griptape.mixins.serializable_mixin import SerializableMixin + + +@define(kw_only=True) +class BaseConfig(SerializableMixin, ABC): ... diff --git a/griptape/configs/defaults_config.py b/griptape/configs/defaults_config.py new file mode 100644 index 000000000..b81f50cdc --- /dev/null +++ b/griptape/configs/defaults_config.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from attrs import Factory, define, field + +from griptape.mixins.singleton_mixin import SingletonMixin + +from .base_config import BaseConfig +from .drivers.openai_drivers_config import OpenAiDriversConfig +from .logging.logging_config import LoggingConfig + +if TYPE_CHECKING: + from .drivers.base_drivers_config import BaseDriversConfig + + +@define(kw_only=True) +class _DefaultsConfig(BaseConfig, SingletonMixin): + logging_config: LoggingConfig = field(default=Factory(lambda: LoggingConfig())) + drivers_config: BaseDriversConfig = field(default=Factory(lambda: OpenAiDriversConfig())) + + +Defaults = _DefaultsConfig() diff --git a/griptape/configs/drivers/__init__.py b/griptape/configs/drivers/__init__.py new file mode 100644 index 000000000..d407814e8 --- /dev/null +++ b/griptape/configs/drivers/__init__.py @@ -0,0 +1,20 @@ +from .base_drivers_config import BaseDriversConfig +from .drivers_config import DriversConfig + +from .openai_drivers_config import OpenAiDriversConfig +from .azure_openai_drivers_config import AzureOpenAiDriversConfig +from .amazon_bedrock_drivers_config import AmazonBedrockDriversConfig +from .anthropic_drivers_config import AnthropicDriversConfig +from .google_drivers_config import GoogleDriversConfig +from .cohere_drivers_config import CohereDriversConfig + +__all__ = [ + "BaseDriversConfig", + "DriversConfig", + "OpenAiDriversConfig", + "AzureOpenAiDriversConfig", + "AmazonBedrockDriversConfig", + "AnthropicDriversConfig", + "GoogleDriversConfig", + "CohereDriversConfig", +] diff --git a/griptape/config/drivers/amazon_bedrock_driver_config.py b/griptape/configs/drivers/amazon_bedrock_drivers_config.py similarity index 95% rename from griptape/config/drivers/amazon_bedrock_driver_config.py rename to griptape/configs/drivers/amazon_bedrock_drivers_config.py index 22d198167..7a54ac522 100644 --- a/griptape/config/drivers/amazon_bedrock_driver_config.py +++ b/griptape/configs/drivers/amazon_bedrock_drivers_config.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config.drivers import DriverConfig +from griptape.configs.drivers import DriversConfig from griptape.drivers import ( AmazonBedrockImageGenerationDriver, AmazonBedrockImageQueryDriver, @@ -22,7 +22,7 @@ @define -class AmazonBedrockDriverConfig(DriverConfig): +class AmazonBedrockDriversConfig(DriversConfig): session: boto3.Session = field( default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True, diff --git a/griptape/config/drivers/anthropic_driver_config.py b/griptape/configs/drivers/anthropic_drivers_config.py similarity index 90% rename from griptape/config/drivers/anthropic_driver_config.py rename to griptape/configs/drivers/anthropic_drivers_config.py index b036d85f4..e5a1f2719 100644 --- a/griptape/config/drivers/anthropic_driver_config.py +++ b/griptape/configs/drivers/anthropic_drivers_config.py @@ -1,6 +1,6 @@ from attrs import define -from griptape.config.drivers import DriverConfig +from griptape.configs.drivers import DriversConfig from griptape.drivers import ( AnthropicImageQueryDriver, AnthropicPromptDriver, @@ -11,7 +11,7 @@ @define -class AnthropicDriverConfig(DriverConfig): +class AnthropicDriversConfig(DriversConfig): @lazy_property() def prompt_driver(self) -> AnthropicPromptDriver: return AnthropicPromptDriver(model="claude-3-5-sonnet-20240620") diff --git a/griptape/config/drivers/azure_openai_driver_config.py b/griptape/configs/drivers/azure_openai_drivers_config.py similarity index 96% rename from griptape/config/drivers/azure_openai_driver_config.py rename to griptape/configs/drivers/azure_openai_drivers_config.py index f27c8970c..a29ba3c2f 100644 --- a/griptape/config/drivers/azure_openai_driver_config.py +++ b/griptape/configs/drivers/azure_openai_drivers_config.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.config.drivers import DriverConfig +from griptape.configs.drivers import DriversConfig from griptape.drivers import ( AzureOpenAiChatPromptDriver, AzureOpenAiEmbeddingDriver, @@ -16,8 +16,8 @@ @define -class AzureOpenAiDriverConfig(DriverConfig): - """Azure OpenAI Driver Configuration. +class AzureOpenAiDriversConfig(DriversConfig): + """Azure OpenAI Drivers Configuration. Attributes: azure_endpoint: The endpoint for the Azure OpenAI instance. diff --git a/griptape/config/drivers/base_driver_config.py b/griptape/configs/drivers/base_drivers_config.py similarity index 98% rename from griptape/config/drivers/base_driver_config.py rename to griptape/configs/drivers/base_drivers_config.py index d8555052f..ec7503478 100644 --- a/griptape/config/drivers/base_driver_config.py +++ b/griptape/configs/drivers/base_drivers_config.py @@ -22,7 +22,7 @@ @define -class BaseDriverConfig(ABC, SerializableMixin): +class BaseDriversConfig(ABC, SerializableMixin): _prompt_driver: BasePromptDriver = field( kw_only=True, default=None, metadata={"serializable": True}, alias="prompt_driver" ) diff --git a/griptape/config/drivers/cohere_driver_config.py b/griptape/configs/drivers/cohere_drivers_config.py similarity index 91% rename from griptape/config/drivers/cohere_driver_config.py rename to griptape/configs/drivers/cohere_drivers_config.py index 25dc833e5..b5d8da8b0 100644 --- a/griptape/config/drivers/cohere_driver_config.py +++ b/griptape/configs/drivers/cohere_drivers_config.py @@ -1,6 +1,6 @@ from attrs import define, field -from griptape.config.drivers import DriverConfig +from griptape.configs.drivers import DriversConfig from griptape.drivers import ( CohereEmbeddingDriver, CoherePromptDriver, @@ -10,7 +10,7 @@ @define -class CohereDriverConfig(DriverConfig): +class CohereDriversConfig(DriversConfig): api_key: str = field(metadata={"serializable": False}, kw_only=True) @lazy_property() diff --git a/griptape/config/drivers/driver_config.py b/griptape/configs/drivers/drivers_config.py similarity index 94% rename from griptape/config/drivers/driver_config.py rename to griptape/configs/drivers/drivers_config.py index 16cb9a535..ed68bcf8c 100644 --- a/griptape/config/drivers/driver_config.py +++ b/griptape/configs/drivers/drivers_config.py @@ -4,7 +4,7 @@ from attrs import define -from griptape.config.drivers import BaseDriverConfig +from griptape.configs.drivers import BaseDriversConfig from griptape.drivers import ( DummyAudioTranscriptionDriver, DummyEmbeddingDriver, @@ -30,7 +30,7 @@ @define -class DriverConfig(BaseDriverConfig): +class DriversConfig(BaseDriversConfig): @lazy_property() def prompt_driver(self) -> BasePromptDriver: return DummyPromptDriver() diff --git a/griptape/config/drivers/google_driver_config.py b/griptape/configs/drivers/google_drivers_config.py similarity index 87% rename from griptape/config/drivers/google_driver_config.py rename to griptape/configs/drivers/google_drivers_config.py index 0ab72e6bb..8d5325235 100644 --- a/griptape/config/drivers/google_driver_config.py +++ b/griptape/configs/drivers/google_drivers_config.py @@ -1,6 +1,6 @@ from attrs import define -from griptape.config.drivers import DriverConfig +from griptape.configs.drivers import DriversConfig from griptape.drivers import ( GoogleEmbeddingDriver, GooglePromptDriver, @@ -10,7 +10,7 @@ @define -class GoogleDriverConfig(DriverConfig): +class GoogleDriversConfig(DriversConfig): @lazy_property() def prompt_driver(self) -> GooglePromptDriver: return GooglePromptDriver(model="gemini-1.5-pro") diff --git a/griptape/config/drivers/openai_driver_config.py b/griptape/configs/drivers/openai_drivers_config.py similarity index 93% rename from griptape/config/drivers/openai_driver_config.py rename to griptape/configs/drivers/openai_drivers_config.py index 49cf60206..205cfb0e1 100644 --- a/griptape/config/drivers/openai_driver_config.py +++ b/griptape/configs/drivers/openai_drivers_config.py @@ -1,6 +1,6 @@ from attrs import define -from griptape.config.drivers import DriverConfig +from griptape.configs.drivers import DriversConfig from griptape.drivers import ( LocalVectorStoreDriver, OpenAiAudioTranscriptionDriver, @@ -14,7 +14,7 @@ @define -class OpenAiDriverConfig(DriverConfig): +class OpenAiDriversConfig(DriversConfig): @lazy_property() def prompt_driver(self) -> OpenAiChatPromptDriver: return OpenAiChatPromptDriver(model="gpt-4o") diff --git a/griptape/config/logging/__init__.py b/griptape/configs/logging/__init__.py similarity index 100% rename from griptape/config/logging/__init__.py rename to griptape/configs/logging/__init__.py diff --git a/griptape/config/logging/logging_config.py b/griptape/configs/logging/logging_config.py similarity index 100% rename from griptape/config/logging/logging_config.py rename to griptape/configs/logging/logging_config.py diff --git a/griptape/config/logging/newline_logging_filter.py b/griptape/configs/logging/newline_logging_filter.py similarity index 100% rename from griptape/config/logging/newline_logging_filter.py rename to griptape/configs/logging/newline_logging_filter.py diff --git a/griptape/config/logging/truncate_logging_filter.py b/griptape/configs/logging/truncate_logging_filter.py similarity index 100% rename from griptape/config/logging/truncate_logging_filter.py rename to griptape/configs/logging/truncate_logging_filter.py diff --git a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py index 91e7f4909..ae46c474c 100644 --- a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent, event_bus +from griptape.events import EventBus, FinishAudioTranscriptionEvent, StartAudioTranscriptionEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, A model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self) -> None: - event_bus.publish_event(StartAudioTranscriptionEvent()) + EventBus.publish_event(StartAudioTranscriptionEvent()) def after_run(self) -> None: - event_bus.publish_event(FinishAudioTranscriptionEvent()) + EventBus.publish_event(FinishAudioTranscriptionEvent()) def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_generation/base_image_generation_driver.py b/griptape/drivers/image_generation/base_image_generation_driver.py index 360fba8c9..8dfca5945 100644 --- a/griptape/drivers/image_generation/base_image_generation_driver.py +++ b/griptape/drivers/image_generation/base_image_generation_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import FinishImageGenerationEvent, StartImageGenerationEvent, event_bus +from griptape.events import EventBus, FinishImageGenerationEvent, StartImageGenerationEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,10 +17,10 @@ class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC) model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None: - event_bus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) + EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)) def after_run(self) -> None: - event_bus.publish_event(FinishImageGenerationEvent()) + EventBus.publish_event(FinishImageGenerationEvent()) def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index b1050b85c..28c571328 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import FinishImageQueryEvent, StartImageQueryEvent, event_bus +from griptape.events import EventBus, FinishImageQueryEvent, StartImageQueryEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -17,12 +17,12 @@ class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: - event_bus.publish_event( + EventBus.publish_event( StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images]), ) def after_run(self, result: str) -> None: - event_bus.publish_event(FinishImageQueryEvent(result=result)) + EventBus.publish_event(FinishImageQueryEvent(result=result)) def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: for attempt in self.retrying(): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 43b31306c..c07980c9e 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,7 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent, event_bus +from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin if TYPE_CHECKING: @@ -49,10 +49,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def before_run(self, prompt_stack: PromptStack) -> None: - event_bus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: - event_bus.publish_event( + EventBus.publish_event( FinishPromptEvent( model=self.model, result=result.value, @@ -126,12 +126,12 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - event_bus.publish_event(CompletionChunkEvent(token=content.text)) + EventBus.publish_event(CompletionChunkEvent(token=content.text)) elif isinstance(content, ActionCallDeltaMessageContent): if content.tag is not None and content.name is not None and content.path is not None: - event_bus.publish_event(CompletionChunkEvent(token=str(content))) + EventBus.publish_event(CompletionChunkEvent(token=str(content))) elif content.partial_input is not None: - event_bus.publish_event(CompletionChunkEvent(token=content.partial_input)) + EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) # Build a complete content from the content deltas return self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py index c74264dc1..cb11cc498 100644 --- a/griptape/drivers/text_to_speech/base_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.events import event_bus +from griptape.events import EventBus from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -19,10 +19,10 @@ class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) def before_run(self, prompts: list[str]) -> None: - event_bus.publish_event(StartTextToSpeechEvent(prompts=prompts)) + EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts)) def after_run(self) -> None: - event_bus.publish_event(FinishTextToSpeechEvent()) + EventBus.publish_event(FinishTextToSpeechEvent()) def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact: for attempt in self.retrying(): diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py index ee5739d81..4084c8829 100644 --- a/griptape/engines/audio/audio_transcription_engine.py +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -1,14 +1,14 @@ from attrs import Factory, define, field from griptape.artifacts import AudioArtifact, TextArtifact -from griptape.config import config +from griptape.configs import Defaults from griptape.drivers import BaseAudioTranscriptionDriver @define class AudioTranscriptionEngine: audio_transcription_driver: BaseAudioTranscriptionDriver = field( - default=Factory(lambda: config.driver_config.audio_transcription_driver), kw_only=True + default=Factory(lambda: Defaults.drivers_config.audio_transcription_driver), kw_only=True ) def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py index 3036d8bf5..1261ae369 100644 --- a/griptape/engines/audio/text_to_speech_engine.py +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import config +from griptape.configs import Defaults if TYPE_CHECKING: from griptape.artifacts.audio_artifact import AudioArtifact @@ -14,7 +14,7 @@ @define class TextToSpeechEngine: text_to_speech_driver: BaseTextToSpeechDriver = field( - default=Factory(lambda: config.driver_config.text_to_speech_driver), kw_only=True + default=Factory(lambda: Defaults.drivers_config.text_to_speech_driver), kw_only=True ) def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index 0a28b65b3..fb1fab6c4 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Attribute, Factory, define, field from griptape.chunkers import BaseChunker, TextChunker -from griptape.config import config +from griptape.configs import Defaults if TYPE_CHECKING: from griptape.artifacts import ErrorArtifact, ListArtifact @@ -18,7 +18,9 @@ class BaseExtractionEngine(ABC): max_token_multiplier: float = field(default=0.5, kw_only=True) chunk_joiner: str = field(default="\n\n", kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field( + default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True + ) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/engines/image/base_image_generation_engine.py b/griptape/engines/image/base_image_generation_engine.py index 9f72f16be..5fdc60531 100644 --- a/griptape/engines/image/base_image_generation_engine.py +++ b/griptape/engines/image/base_image_generation_engine.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field -from griptape.config import config +from griptape.configs import Defaults if TYPE_CHECKING: from griptape.artifacts import ImageArtifact @@ -16,7 +16,7 @@ @define class BaseImageGenerationEngine(ABC): image_generation_driver: BaseImageGenerationDriver = field( - kw_only=True, default=Factory(lambda: config.driver_config.image_generation_driver) + kw_only=True, default=Factory(lambda: Defaults.drivers_config.image_generation_driver) ) @abstractmethod diff --git a/griptape/engines/image_query/image_query_engine.py b/griptape/engines/image_query/image_query_engine.py index 1b8fce277..348017e64 100644 --- a/griptape/engines/image_query/image_query_engine.py +++ b/griptape/engines/image_query/image_query_engine.py @@ -4,7 +4,7 @@ from attrs import Factory, define, field -from griptape.config import config +from griptape.configs import Defaults if TYPE_CHECKING: from griptape.artifacts import ImageArtifact, TextArtifact @@ -14,7 +14,7 @@ @define class ImageQueryEngine: image_query_driver: BaseImageQueryDriver = field( - default=Factory(lambda: config.driver_config.image_query_driver), kw_only=True + default=Factory(lambda: Defaults.drivers_config.image_query_driver), kw_only=True ) def run(self, query: str, images: list[ImageArtifact]) -> TextArtifact: diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index d1bcdd3b8..78dfba8f4 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.config import config +from griptape.configs import Defaults from griptape.engines.rag.modules import BaseResponseRagModule from griptape.mixins import RuleMixin from griptape.utils import J2 @@ -18,7 +18,7 @@ @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule, RuleMixin): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver)) + prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver)) answer_token_offset: int = field(default=400) metadata: Optional[str] = field(default=None) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index c04e9025a..ddff2549c 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from griptape import utils -from griptape.config import config +from griptape.configs import Defaults from griptape.engines.rag.modules import BaseRetrievalRagModule if TYPE_CHECKING: @@ -19,7 +19,7 @@ @define(kw_only=True) class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda: config.driver_config.vector_store_driver) + default=Factory(lambda: Defaults.drivers_config.vector_store_driver) ) query_params: dict[str, Any] = field(factory=dict) process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( diff --git a/griptape/engines/rag/rag_context.py b/griptape/engines/rag/rag_context.py index 1ddfbb1b0..3dbfc6834 100644 --- a/griptape/engines/rag/rag_context.py +++ b/griptape/engines/rag/rag_context.py @@ -18,7 +18,7 @@ class RagContext(SerializableMixin): Attributes: query: Query provided by the user. - module_configs: Dictionary of module configs. First key should be a module name and the second a dictionary of config parameters. + module_configs: Dictionary of module configs. First key should be a module name and the second a dictionary of configs parameters. before_query: An optional list of strings to add before the query in response modules. after_query: An optional list of strings to add after the query in response modules. text_chunks: A list of text chunks to pass around from the retrieval stage to the response stage. diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 04f30ca82..99e133844 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -7,7 +7,7 @@ from griptape.artifacts import ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.common import Message, PromptStack -from griptape.config import config +from griptape.configs import Defaults from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -22,7 +22,9 @@ class PromptSummaryEngine(BaseSummaryEngine): max_token_multiplier: float = field(default=0.5, kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field( + default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True + ) chunker: BaseChunker = field( default=Factory( lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 431927663..b3e2f3a79 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -22,7 +22,7 @@ from .base_audio_transcription_event import BaseAudioTranscriptionEvent from .start_audio_transcription_event import StartAudioTranscriptionEvent from .finish_audio_transcription_event import FinishAudioTranscriptionEvent -from .event_bus import event_bus +from .event_bus import EventBus __all__ = [ "BaseEvent", @@ -49,5 +49,5 @@ "BaseAudioTranscriptionEvent", "StartAudioTranscriptionEvent", "FinishAudioTranscriptionEvent", - "event_bus", + "EventBus", ] diff --git a/griptape/events/event_bus.py b/griptape/events/event_bus.py index a956f7deb..3ddc325ff 100644 --- a/griptape/events/event_bus.py +++ b/griptape/events/event_bus.py @@ -4,12 +4,14 @@ from attrs import define, field +from griptape.mixins.singleton_mixin import SingletonMixin + if TYPE_CHECKING: from griptape.events import BaseEvent, EventListener @define -class _EventBus: +class _EventBus(SingletonMixin): _event_listeners: list[EventListener] = field(factory=list, kw_only=True, alias="_event_listeners") @property @@ -41,4 +43,4 @@ def clear_event_listeners(self) -> None: self._event_listeners.clear() -event_bus = _EventBus() +EventBus = _EventBus() diff --git a/griptape/exceptions/dummy_exception.py b/griptape/exceptions/dummy_exception.py index 172aeadc6..0020ce547 100644 --- a/griptape/exceptions/dummy_exception.py +++ b/griptape/exceptions/dummy_exception.py @@ -2,7 +2,7 @@ class DummyError(Exception): def __init__(self, dummy_class_name: str, dummy_method_name: str) -> None: message = ( f"You have attempted to use a {dummy_class_name}'s {dummy_method_name} method. " - "This likely originated from using a `DriverConfig` without providing a Driver required for this feature." + "This likely originated from using a `DriversConfig` without providing a Driver required for this feature." ) super().__init__(message) diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 86431122a..44c053dc4 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.common import PromptStack -from griptape.config import config +from griptape.configs import Defaults from griptape.mixins import SerializableMixin if TYPE_CHECKING: @@ -18,7 +18,7 @@ @define class BaseConversationMemory(SerializableMixin, ABC): driver: Optional[BaseConversationMemoryDriver] = field( - default=Factory(lambda: config.driver_config.conversation_memory_driver), kw_only=True + default=Factory(lambda: Defaults.drivers_config.conversation_memory_driver), kw_only=True ) runs: list[Run] = field(factory=list, kw_only=True, metadata={"serializable": True}) structure: Structure = field(init=False) @@ -67,7 +67,7 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = if self.autoprune and hasattr(self, "structure"): should_prune = True - prompt_driver = config.driver_config.prompt_driver + prompt_driver = Defaults.drivers_config.prompt_driver temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 736891d90..055057d34 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.common import Message, PromptStack -from griptape.config import config +from griptape.configs import Defaults from griptape.memory.structure import ConversationMemory from griptape.utils import J2 @@ -18,7 +18,9 @@ @define class SummaryConversationMemory(ConversationMemory): offset: int = field(default=1, kw_only=True, metadata={"serializable": True}) - prompt_driver: BasePromptDriver = field(kw_only=True, default=Factory(lambda: config.driver_config.prompt_driver)) + prompt_driver: BasePromptDriver = field( + kw_only=True, default=Factory(lambda: Defaults.drivers_config.prompt_driver) + ) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 5eb3ab734..1560702fb 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact -from griptape.config import config +from griptape.configs import Defaults from griptape.memory.task.storage import BaseArtifactStorage if TYPE_CHECKING: @@ -15,7 +15,7 @@ @define(kw_only=True) class TextArtifactStorage(BaseArtifactStorage): vector_store_driver: BaseVectorStoreDriver = field( - default=Factory(lambda: config.driver_config.vector_store_driver) + default=Factory(lambda: Defaults.drivers_config.vector_store_driver) ) def can_store(self, artifact: BaseArtifact) -> bool: diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index 1bfa95c9a..32e00dd8b 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -5,6 +5,7 @@ from .serializable_mixin import SerializableMixin from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin from .futures_executor_mixin import FuturesExecutorMixin +from .singleton_mixin import SingletonMixin __all__ = [ "ActivityMixin", @@ -14,4 +15,5 @@ "BlobArtifactFileOutputMixin", "SerializableMixin", "FuturesExecutorMixin", + "SingletonMixin", ] diff --git a/griptape/mixins/singleton_mixin.py b/griptape/mixins/singleton_mixin.py new file mode 100644 index 000000000..1d565ceec --- /dev/null +++ b/griptape/mixins/singleton_mixin.py @@ -0,0 +1,10 @@ +from __future__ import annotations + + +class SingletonMixin: + _instance = None + + def __new__(cls, *args, **kwargs) -> SingletonMixin: + if not cls._instance: + cls._instance = super().__new__(cls, *args, **kwargs) # noqa: UP008 + return cls._instance diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 2c4edfc7d..24f57395c 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -6,7 +6,7 @@ from griptape.artifacts.text_artifact import TextArtifact from griptape.common import observable -from griptape.config import config +from griptape.configs import Defaults from griptape.memory.structure import Run from griptape.structures import Structure from griptape.tasks import PromptTask, ToolkitTask @@ -23,8 +23,10 @@ class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) - stream: bool = field(default=Factory(lambda: config.driver_config.prompt_driver.stream), kw_only=True) - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver), kw_only=True) + stream: bool = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver.stream), kw_only=True) + prompt_driver: BasePromptDriver = field( + default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True + ) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) fail_fast: bool = field(default=False, kw_only=True) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index a18e9d578..0572e289d 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.common import observable -from griptape.events import FinishStructureRunEvent, StartStructureRunEvent, event_bus +from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent from griptape.memory import TaskMemory from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory @@ -137,7 +137,7 @@ def before_run(self, args: Any) -> None: [task.reset() for task in self.tasks] - event_bus.publish_event( + EventBus.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, @@ -149,7 +149,7 @@ def before_run(self, args: Any) -> None: @observable def after_run(self) -> None: - event_bus.publish_event( + EventBus.publish_event( FinishStructureRunEvent( structure_id=self.id, output_task_input=self.output_task.input, diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index da6f214eb..7cdb5d4de 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -11,8 +11,8 @@ from griptape import utils from griptape.artifacts import ActionArtifact, BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact from griptape.common import ToolAction -from griptape.config import config -from griptape.events import FinishActionsSubtaskEvent, StartActionsSubtaskEvent, event_bus +from griptape.configs import Defaults +from griptape.events import EventBus, FinishActionsSubtaskEvent, StartActionsSubtaskEvent from griptape.mixins import ActionsSubtaskOriginMixin from griptape.tasks import BaseTask from griptape.utils import remove_null_values_in_dict_recursively @@ -20,7 +20,7 @@ if TYPE_CHECKING: from griptape.memory import TaskMemory -logger = logging.getLogger(config.logging_config.logger_name) +logger = logging.getLogger(Defaults.logging_config.logger_name) @define @@ -93,7 +93,7 @@ def attach_to(self, parent_task: BaseTask) -> None: self.output = ErrorArtifact(f"ToolAction input parsing error: {e}", exception=e) def before_run(self) -> None: - event_bus.publish_event( + EventBus.publish_event( StartActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -156,7 +156,7 @@ def execute_action(self, action: ToolAction) -> tuple[str, BaseArtifact]: def after_run(self) -> None: response = self.output.to_text() if isinstance(self.output, BaseArtifact) else str(self.output) - event_bus.publish_event( + EventBus.publish_event( FinishActionsSubtaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index b24ec3f3c..519a1a59a 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -5,11 +5,11 @@ from attrs import define -from griptape.config import config +from griptape.configs import Defaults from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.tasks import BaseTask -logger = logging.getLogger(config.logging_config.logger_name) +logger = logging.getLogger(Defaults.logging_config.logger_name) @define diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py index fc94ed5d4..e39f70fcd 100644 --- a/griptape/tasks/base_audio_input_task.py +++ b/griptape/tasks/base_audio_input_task.py @@ -7,11 +7,11 @@ from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact -from griptape.config.config import config +from griptape.configs import Defaults from griptape.mixins import RuleMixin from griptape.tasks import BaseTask -logger = logging.getLogger(config.logging_config.logger_name) +logger = logging.getLogger(Defaults.logging_config.logger_name) @define diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index 9c226256b..f0c1f0e7e 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -8,7 +8,7 @@ from attrs import Attribute, define, field -from griptape.config import config +from griptape.configs import Defaults from griptape.loaders import ImageLoader from griptape.mixins import BlobArtifactFileOutputMixin, RuleMixin from griptape.rules import Rule, Ruleset @@ -18,7 +18,7 @@ from griptape.artifacts import MediaArtifact -logger = logging.getLogger(config.logging_config.logger_name) +logger = logging.getLogger(Defaults.logging_config.logger_name) @define diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py index 52b68e4e0..347dd7e29 100644 --- a/griptape/tasks/base_multi_text_input_task.py +++ b/griptape/tasks/base_multi_text_input_task.py @@ -7,12 +7,12 @@ from attrs import Factory, define, field from griptape.artifacts import ListArtifact, TextArtifact -from griptape.config import config +from griptape.configs import Defaults from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 -logger = logging.getLogger(config.logging_config.logger_name) +logger = logging.getLogger(Defaults.logging_config.logger_name) @define diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index d80767793..535b3a92d 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -9,8 +9,8 @@ from attrs import Factory, define, field from griptape.artifacts import ErrorArtifact -from griptape.config import config -from griptape.events import FinishTaskEvent, StartTaskEvent, event_bus +from griptape.configs import Defaults +from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent from griptape.mixins import FuturesExecutorMixin if TYPE_CHECKING: @@ -18,7 +18,7 @@ from griptape.memory.meta import BaseMetaEntry from griptape.structures import Structure -logger = logging.getLogger(config.logging_config.logger_name) +logger = logging.getLogger(Defaults.logging_config.logger_name) @define @@ -137,7 +137,7 @@ def is_executing(self) -> bool: def before_run(self) -> None: if self.structure is not None: - event_bus.publish_event( + EventBus.publish_event( StartTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, @@ -149,7 +149,7 @@ def before_run(self) -> None: def after_run(self) -> None: if self.structure is not None: - event_bus.publish_event( + EventBus.publish_event( FinishTaskEvent( task_id=self.id, task_parent_ids=self.parent_ids, diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 0a53b9fcd..dfed85bcf 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -7,12 +7,12 @@ from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.config import config +from griptape.configs import Defaults from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 -logger = logging.getLogger(config.logging_config.logger_name) +logger = logging.getLogger(Defaults.logging_config.logger_name) @define diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 719ae77bc..17a73e4cd 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -7,7 +7,7 @@ from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack -from griptape.config import config +from griptape.configs import Defaults from griptape.mixins import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -15,12 +15,14 @@ if TYPE_CHECKING: from griptape.drivers import BasePromptDriver -logger = logging.getLogger(config.logging_config.logger_name) +logger = logging.getLogger(Defaults.logging_config.logger_name) @define class PromptTask(RuleMixin, BaseTask): - prompt_driver: BasePromptDriver = field(default=Factory(lambda: config.driver_config.prompt_driver), kw_only=True) + prompt_driver: BasePromptDriver = field( + default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True + ) generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, diff --git a/griptape/tools/query/tool.py b/griptape/tools/query/tool.py index 70cc9f747..3bc954239 100644 --- a/griptape/tools/query/tool.py +++ b/griptape/tools/query/tool.py @@ -4,7 +4,7 @@ from schema import Literal, Or, Schema from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact, TextArtifact -from griptape.config import config +from griptape.configs import Defaults from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( PromptResponseRagModule, @@ -26,7 +26,7 @@ class QueryTool(BaseTool, RuleMixin): response_stage=ResponseRagStage( response_modules=[ PromptResponseRagModule( - prompt_driver=config.driver_config.prompt_driver, rulesets=self.rulesets + prompt_driver=Defaults.drivers_config.prompt_driver, rulesets=self.rulesets ) ], ), diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index f30e9f1cd..21e045db7 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -38,11 +38,11 @@ def default_output_fn(self, text: str) -> None: print(text) # noqa: T201 def start(self) -> None: - from griptape.config import config + from griptape.configs import Defaults # Hide Griptape's logging output except for errors - old_logger_level = logging.getLogger(config.logging_config.logger_name).getEffectiveLevel() - logging.getLogger(config.logging_config.logger_name).setLevel(self.logger_level) + old_logger_level = logging.getLogger(Defaults.logging_config.logger_name).getEffectiveLevel() + logging.getLogger(Defaults.logging_config.logger_name).setLevel(self.logger_level) if self.intro_text: self.output_fn(self.intro_text) @@ -53,7 +53,7 @@ def start(self) -> None: self.output_fn(self.exiting_text) break - if config.driver_config.prompt_driver.stream: + if Defaults.drivers_config.prompt_driver.stream: self.output_fn(self.processing_text + "\n") stream = Stream(self.structure).run(question) first_chunk = next(stream) @@ -65,4 +65,4 @@ def start(self) -> None: self.output_fn(f"{self.response_prefix}{self.structure.run(question).output_task.output.to_text()}") # Restore the original logger level - logging.getLogger(config.logging_config.logger_name).setLevel(old_logger_level) + logging.getLogger(Defaults.logging_config.logger_name).setLevel(old_logger_level) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 6da58b9e6..8a764e85a 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -7,7 +7,7 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventListener, FinishPromptEvent, FinishStructureRunEvent, event_bus +from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent if TYPE_CHECKING: from collections.abc import Iterator @@ -66,8 +66,8 @@ def event_handler(event: BaseEvent) -> None: handler=event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) - event_bus.add_event_listener(stream_event_listener) + EventBus.add_event_listener(stream_event_listener) self.structure.run(*args) - event_bus.remove_event_listener(stream_event_listener) + EventBus.remove_event_listener(stream_event_listener) diff --git a/mkdocs.yml b/mkdocs.yml index 7e3806264..4207d2171 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -91,7 +91,7 @@ nav: - Task Memory and Off Prompt: "griptape-framework/structures/task-memory.md" - Conversation Memory: "griptape-framework/structures/conversation-memory.md" - Rulesets: "griptape-framework/structures/rulesets.md" - - Config: "griptape-framework/structures/config.md" + - Configs: "griptape-framework/structures/configs.md" - Observability: "griptape-framework/structures/observability.md" - Tools: - Overview: "griptape-framework/tools/index.md" diff --git a/tests/mocks/mock_driver_config.py b/tests/mocks/mock_drivers_config.py similarity index 92% rename from tests/mocks/mock_driver_config.py rename to tests/mocks/mock_drivers_config.py index b038fe920..aa9683dbd 100644 --- a/tests/mocks/mock_driver_config.py +++ b/tests/mocks/mock_drivers_config.py @@ -1,6 +1,6 @@ from attrs import define -from griptape.config.drivers import DriverConfig +from griptape.configs.drivers import DriversConfig from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver from griptape.utils.decorators import lazy_property from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -10,7 +10,7 @@ @define -class MockDriverConfig(DriverConfig): +class MockDriversConfig(DriversConfig): @lazy_property() def prompt_driver(self) -> MockPromptDriver: return MockPromptDriver() diff --git a/tests/unit/config/drivers/test_driver_config.py b/tests/unit/config/drivers/test_driver_config.py deleted file mode 100644 index beb42a018..000000000 --- a/tests/unit/config/drivers/test_driver_config.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest - -from griptape.config.drivers import DriverConfig - - -class TestDriverConfig: - @pytest.fixture() - def config(self): - return DriverConfig() - - def test_to_dict(self, config): - assert config.to_dict() == { - "type": "DriverConfig", - "prompt_driver": { - "type": "DummyPromptDriver", - "temperature": 0.1, - "max_tokens": None, - "stream": False, - "use_native_tools": False, - }, - "conversation_memory_driver": None, - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, - "vector_store_driver": { - "embedding_driver": {"type": "DummyEmbeddingDriver"}, - "type": "DummyVectorStoreDriver", - }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, - "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, - } - - def test_from_dict(self, config): - assert DriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() - - def test_dot_update(self, config): - config.prompt_driver.max_tokens = 10 - - assert config.prompt_driver.max_tokens == 10 - - @pytest.mark.skip_mock_config() - def test_lazy_init(self): - from griptape.config import config - - assert config.driver_config._prompt_driver is None - assert config.driver_config._image_generation_driver is None - assert config.driver_config._image_query_driver is None - assert config.driver_config._embedding_driver is None - assert config.driver_config._vector_store_driver is None - assert config.driver_config._conversation_memory_driver is None - assert config.driver_config._text_to_speech_driver is None - assert config.driver_config._audio_transcription_driver is None - - assert config.driver_config.prompt_driver is not None - assert config.driver_config.image_generation_driver is not None - assert config.driver_config.image_query_driver is not None - assert config.driver_config.embedding_driver is not None - assert config.driver_config.vector_store_driver is not None - assert config.driver_config.conversation_memory_driver is None - assert config.driver_config.text_to_speech_driver is not None - assert config.driver_config.audio_transcription_driver is not None - - assert config.driver_config._prompt_driver is not None - assert config.driver_config._image_generation_driver is not None - assert config.driver_config._image_query_driver is not None - assert config.driver_config._embedding_driver is not None - assert config.driver_config._vector_store_driver is not None - assert config.driver_config._conversation_memory_driver is None - assert config.driver_config._text_to_speech_driver is not None - assert config.driver_config._audio_transcription_driver is not None diff --git a/tests/unit/config/__init__.py b/tests/unit/configs/__init__.py similarity index 100% rename from tests/unit/config/__init__.py rename to tests/unit/configs/__init__.py diff --git a/tests/unit/config/drivers/__init__.py b/tests/unit/configs/drivers/__init__.py similarity index 100% rename from tests/unit/config/drivers/__init__.py rename to tests/unit/configs/drivers/__init__.py diff --git a/tests/unit/config/drivers/test_amazon_bedrock_driver_config.py b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py similarity index 89% rename from tests/unit/config/drivers/test_amazon_bedrock_driver_config.py rename to tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py index e30444332..129fe281f 100644 --- a/tests/unit/config/drivers/test_amazon_bedrock_driver_config.py +++ b/tests/unit/configs/drivers/test_amazon_bedrock_drivers_config.py @@ -1,11 +1,11 @@ import boto3 import pytest -from griptape.config.drivers import AmazonBedrockDriverConfig +from griptape.configs.drivers import AmazonBedrockDriversConfig from tests.utils.aws import mock_aws_credentials -class TestAmazonBedrockDriverConfig: +class TestAmazonBedrockDriversConfig: @pytest.fixture(autouse=True) def _run_before_and_after_tests(self): mock_aws_credentials() @@ -13,11 +13,11 @@ def _run_before_and_after_tests(self): @pytest.fixture() def config(self): mock_aws_credentials() - return AmazonBedrockDriverConfig() + return AmazonBedrockDriversConfig() @pytest.fixture() def config_with_values(self): - return AmazonBedrockDriverConfig( + return AmazonBedrockDriversConfig( session=boto3.Session( aws_access_key_id="testing", aws_secret_access_key="testing", region_name="region-value" ) @@ -62,17 +62,17 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, - "type": "AmazonBedrockDriverConfig", + "type": "AmazonBedrockDriversConfig", "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): - assert AmazonBedrockDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert AmazonBedrockDriversConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() def test_from_dict_with_values(self, config_with_values): assert ( - AmazonBedrockDriverConfig.from_dict(config_with_values.to_dict()).to_dict() == config_with_values.to_dict() + AmazonBedrockDriversConfig.from_dict(config_with_values.to_dict()).to_dict() == config_with_values.to_dict() ) def test_to_dict_with_values(self, config_with_values): @@ -114,7 +114,7 @@ def test_to_dict_with_values(self, config_with_values): }, "type": "LocalVectorStoreDriver", }, - "type": "AmazonBedrockDriverConfig", + "type": "AmazonBedrockDriversConfig", "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } diff --git a/tests/unit/config/drivers/test_anthropic_driver_config.py b/tests/unit/configs/drivers/test_anthropic_drivers_config.py similarity index 85% rename from tests/unit/config/drivers/test_anthropic_driver_config.py rename to tests/unit/configs/drivers/test_anthropic_drivers_config.py index 770a04b9f..b2335d92a 100644 --- a/tests/unit/config/drivers/test_anthropic_driver_config.py +++ b/tests/unit/configs/drivers/test_anthropic_drivers_config.py @@ -1,9 +1,9 @@ import pytest -from griptape.config.drivers import AnthropicDriverConfig +from griptape.configs.drivers import AnthropicDriversConfig -class TestAnthropicDriverConfig: +class TestAnthropicDriversConfig: @pytest.fixture(autouse=True) def _mock_anthropic(self, mocker): mocker.patch("anthropic.Anthropic") @@ -11,11 +11,11 @@ def _mock_anthropic(self, mocker): @pytest.fixture() def config(self): - return AnthropicDriverConfig() + return AnthropicDriversConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "AnthropicDriverConfig", + "type": "AnthropicDriversConfig", "prompt_driver": { "type": "AnthropicPromptDriver", "temperature": 0.1, @@ -51,4 +51,4 @@ def test_to_dict(self, config): } def test_from_dict(self, config): - assert AnthropicDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert AnthropicDriversConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/drivers/test_azure_openai_driver_config.py b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py similarity index 94% rename from tests/unit/config/drivers/test_azure_openai_driver_config.py rename to tests/unit/configs/drivers/test_azure_openai_drivers_config.py index dfc69ce46..5c514c947 100644 --- a/tests/unit/config/drivers/test_azure_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_azure_openai_drivers_config.py @@ -1,16 +1,16 @@ import pytest -from griptape.config.drivers import AzureOpenAiDriverConfig +from griptape.configs.drivers import AzureOpenAiDriversConfig -class TestAzureOpenAiDriverConfig: +class TestAzureOpenAiDriversConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.AzureOpenAI") @pytest.fixture() def config(self): - return AzureOpenAiDriverConfig( + return AzureOpenAiDriversConfig( azure_endpoint="http://localhost:8080", azure_ad_token="test-token", azure_ad_token_provider=lambda: "test-provider", @@ -18,7 +18,7 @@ def config(self): def test_to_dict(self, config): assert config.to_dict() == { - "type": "AzureOpenAiDriverConfig", + "type": "AzureOpenAiDriversConfig", "azure_endpoint": "http://localhost:8080", "prompt_driver": { "type": "AzureOpenAiChatPromptDriver", diff --git a/tests/unit/config/drivers/test_cohere_driver_config.py b/tests/unit/configs/drivers/test_cohere_drivers_config.py similarity index 87% rename from tests/unit/config/drivers/test_cohere_driver_config.py rename to tests/unit/configs/drivers/test_cohere_drivers_config.py index 982733dd6..3c267d73d 100644 --- a/tests/unit/config/drivers/test_cohere_driver_config.py +++ b/tests/unit/configs/drivers/test_cohere_drivers_config.py @@ -1,16 +1,16 @@ import pytest -from griptape.config.drivers import CohereDriverConfig +from griptape.configs.drivers import CohereDriversConfig -class TestCohereDriverConfig: +class TestCohereDriversConfig: @pytest.fixture() def config(self): - return CohereDriverConfig(api_key="api_key") + return CohereDriversConfig(api_key="api_key") def test_to_dict(self, config): assert config.to_dict() == { - "type": "CohereDriverConfig", + "type": "CohereDriversConfig", "image_generation_driver": {"type": "DummyImageGenerationDriver"}, "image_query_driver": {"type": "DummyImageQueryDriver"}, "conversation_memory_driver": None, diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py new file mode 100644 index 000000000..20cc0926c --- /dev/null +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -0,0 +1,70 @@ +import pytest + +from griptape.configs.drivers import DriversConfig + + +class TestDriversConfig: + @pytest.fixture() + def config(self): + return DriversConfig() + + def test_to_dict(self, config): + assert config.to_dict() == { + "type": "DriversConfig", + "prompt_driver": { + "type": "DummyPromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "use_native_tools": False, + }, + "conversation_memory_driver": None, + "embedding_driver": {"type": "DummyEmbeddingDriver"}, + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "vector_store_driver": { + "embedding_driver": {"type": "DummyEmbeddingDriver"}, + "type": "DummyVectorStoreDriver", + }, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + } + + def test_from_dict(self, config): + assert DriversConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + + def test_dot_update(self, config): + config.prompt_driver.max_tokens = 10 + + assert config.prompt_driver.max_tokens == 10 + + @pytest.mark.skip_mock_config() + def test_lazy_init(self): + from griptape.configs import Defaults + + assert Defaults.drivers_config._prompt_driver is None + assert Defaults.drivers_config._image_generation_driver is None + assert Defaults.drivers_config._image_query_driver is None + assert Defaults.drivers_config._embedding_driver is None + assert Defaults.drivers_config._vector_store_driver is None + assert Defaults.drivers_config._conversation_memory_driver is None + assert Defaults.drivers_config._text_to_speech_driver is None + assert Defaults.drivers_config._audio_transcription_driver is None + + assert Defaults.drivers_config.prompt_driver is not None + assert Defaults.drivers_config.image_generation_driver is not None + assert Defaults.drivers_config.image_query_driver is not None + assert Defaults.drivers_config.embedding_driver is not None + assert Defaults.drivers_config.vector_store_driver is not None + assert Defaults.drivers_config.conversation_memory_driver is None + assert Defaults.drivers_config.text_to_speech_driver is not None + assert Defaults.drivers_config.audio_transcription_driver is not None + + assert Defaults.drivers_config._prompt_driver is not None + assert Defaults.drivers_config._image_generation_driver is not None + assert Defaults.drivers_config._image_query_driver is not None + assert Defaults.drivers_config._embedding_driver is not None + assert Defaults.drivers_config._vector_store_driver is not None + assert Defaults.drivers_config._conversation_memory_driver is None + assert Defaults.drivers_config._text_to_speech_driver is not None + assert Defaults.drivers_config._audio_transcription_driver is not None diff --git a/tests/unit/config/drivers/test_google_driver_config.py b/tests/unit/configs/drivers/test_google_drivers_config.py similarity index 86% rename from tests/unit/config/drivers/test_google_driver_config.py rename to tests/unit/configs/drivers/test_google_drivers_config.py index e16f63eb3..f6df1afef 100644 --- a/tests/unit/config/drivers/test_google_driver_config.py +++ b/tests/unit/configs/drivers/test_google_drivers_config.py @@ -1,20 +1,20 @@ import pytest -from griptape.config.drivers import GoogleDriverConfig +from griptape.configs.drivers import GoogleDriversConfig -class TestGoogleDriverConfig: +class TestGoogleDriversConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("google.generativeai.GenerativeModel") @pytest.fixture() def config(self): - return GoogleDriverConfig() + return GoogleDriversConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "GoogleDriverConfig", + "type": "GoogleDriversConfig", "prompt_driver": { "type": "GooglePromptDriver", "temperature": 0.1, @@ -49,4 +49,4 @@ def test_to_dict(self, config): } def test_from_dict(self, config): - assert GoogleDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert GoogleDriversConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/drivers/test_openai_driver_config.py b/tests/unit/configs/drivers/test_openai_driver_config.py similarity index 91% rename from tests/unit/config/drivers/test_openai_driver_config.py rename to tests/unit/configs/drivers/test_openai_driver_config.py index 5c560a7f7..2425b178f 100644 --- a/tests/unit/config/drivers/test_openai_driver_config.py +++ b/tests/unit/configs/drivers/test_openai_driver_config.py @@ -1,20 +1,20 @@ import pytest -from griptape.config.drivers import OpenAiDriverConfig +from griptape.configs.drivers import OpenAiDriversConfig -class TestOpenAiDriverConfig: +class TestOpenAiDriversConfig: @pytest.fixture(autouse=True) def mock_openai(self, mocker): return mocker.patch("openai.OpenAI") @pytest.fixture() def config(self): - return OpenAiDriverConfig() + return OpenAiDriversConfig() def test_to_dict(self, config): assert config.to_dict() == { - "type": "OpenAiDriverConfig", + "type": "OpenAiDriversConfig", "prompt_driver": { "type": "OpenAiChatPromptDriver", "base_url": None, @@ -83,4 +83,4 @@ def test_to_dict(self, config): } def test_from_dict(self, config): - assert OpenAiDriverConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() + assert OpenAiDriversConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() diff --git a/tests/unit/config/logging/__init__.py b/tests/unit/configs/logging/__init__.py similarity index 100% rename from tests/unit/config/logging/__init__.py rename to tests/unit/configs/logging/__init__.py diff --git a/tests/unit/config/logging/test_newline_logging_filter.py b/tests/unit/configs/logging/test_newline_logging_filter.py similarity index 74% rename from tests/unit/config/logging/test_newline_logging_filter.py rename to tests/unit/configs/logging/test_newline_logging_filter.py index 89166dd40..05d022dca 100644 --- a/tests/unit/config/logging/test_newline_logging_filter.py +++ b/tests/unit/configs/logging/test_newline_logging_filter.py @@ -2,15 +2,15 @@ import logging from contextlib import redirect_stdout -from griptape.config import config -from griptape.config.logging import NewlineLoggingFilter +from griptape.configs import Defaults +from griptape.configs.logging import NewlineLoggingFilter from griptape.structures import Agent class TestNewlineLoggingFilter: def test_filter(self): # use the filter in an Agent - logger = logging.getLogger(config.logging_config.logger_name) + logger = logging.getLogger(Defaults.logging_config.logger_name) logger.addFilter(NewlineLoggingFilter(replace_str="$$$")) agent = Agent() # use a context manager to capture the stdout diff --git a/tests/unit/config/logging/test_truncate_logging_filter.py b/tests/unit/configs/logging/test_truncate_logging_filter.py similarity index 75% rename from tests/unit/config/logging/test_truncate_logging_filter.py rename to tests/unit/configs/logging/test_truncate_logging_filter.py index a9387b52b..8aade25f7 100644 --- a/tests/unit/config/logging/test_truncate_logging_filter.py +++ b/tests/unit/configs/logging/test_truncate_logging_filter.py @@ -2,15 +2,15 @@ import logging from contextlib import redirect_stdout -from griptape.config import config -from griptape.config.logging import TruncateLoggingFilter +from griptape.configs import Defaults +from griptape.configs.logging import TruncateLoggingFilter from griptape.structures import Agent class TestTruncateLoggingFilter: def test_filter(self): # use the filter in an Agent - logger = logging.getLogger(config.logging_config.logger_name) + logger = logging.getLogger(Defaults.logging_config.logger_name) logger.addFilter(TruncateLoggingFilter(max_log_length=0)) agent = Agent() # use a context manager to capture the stdout diff --git a/tests/unit/configs/test_defaults_config.py b/tests/unit/configs/test_defaults_config.py new file mode 100644 index 000000000..afb679dd0 --- /dev/null +++ b/tests/unit/configs/test_defaults_config.py @@ -0,0 +1,14 @@ +import pytest + + +class TestDefaultsConfig: + def test_init(self): + from griptape.configs.defaults_config import _DefaultsConfig + + assert _DefaultsConfig() is _DefaultsConfig() + + def test_error_init(self): + from griptape.configs import Defaults + + with pytest.raises(TypeError): + Defaults() # pyright: ignore[reportCallIssue] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 8c954c1c8..a70b6b1a7 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,29 +1,29 @@ import pytest -from tests.mocks.mock_driver_config import MockDriverConfig +from tests.mocks.mock_drivers_config import MockDriversConfig @pytest.fixture(autouse=True) def mock_event_bus(): - from griptape.events import event_bus + from griptape.events import EventBus - event_bus.clear_event_listeners() + EventBus.clear_event_listeners() - yield event_bus + yield EventBus - event_bus.clear_event_listeners() + EventBus.clear_event_listeners() @pytest.fixture(autouse=True) def mock_config(request): - from griptape.config import config + from griptape.configs import Defaults - # Some tests we don't want to use the autouse fixture's MockDriverConfig + # Some tests we don't want to use the autouse fixture's MockDriversConfig if "skip_mock_config" in request.keywords: yield return - config.driver_config = MockDriverConfig() + Defaults.drivers_config = MockDriversConfig() - yield config + yield Defaults diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 61ef3aa53..29aecfdf9 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact -from griptape.events import EventListener, event_bus +from griptape.events import EventBus, EventListener from tests.mocks.mock_audio_transcription_driver import MockAudioTranscriptionDriver @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver, mock_config): mock_handler = Mock() - event_bus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index ab7b33ae8..96b615a58 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts.image_artifact import ImageArtifact -from griptape.events import event_bus +from griptape.events import EventBus from griptape.events.event_listener import EventListener from tests.mocks.mock_image_generation_driver import MockImageGenerationDriver @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - event_bus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - event_bus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - event_bus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - event_bus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index d8ba6b60f..a77fb268e 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventListener, event_bus +from griptape.events import EventBus, EventListener from tests.mocks.mock_image_query_driver import MockImageQueryDriver @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - event_bus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 3e0b0ffc8..3efe85c98 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -11,7 +11,7 @@ class TestBasePromptDriver: def test_run_via_pipeline_retries_success(self, mock_config): - mock_config.driver_config.prompt_driver = MockPromptDriver(max_attempts=2) + mock_config.drivers_config.prompt_driver = MockPromptDriver(max_attempts=2) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -19,7 +19,7 @@ def test_run_via_pipeline_retries_success(self, mock_config): assert isinstance(pipeline.run().output_task.output, TextArtifact) def test_run_via_pipeline_retries_failure(self, mock_config): - mock_config.driver_config.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) + mock_config.drivers_config.prompt_driver = MockFailingPromptDriver(max_failures=2, max_attempts=1) pipeline = Pipeline() pipeline.add_task(PromptTask("test")) @@ -46,7 +46,7 @@ def test_run_with_stream(self): assert result.value == "mock output" def test_run_with_tools(self, mock_config): - mock_config.driver_config.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) + mock_config.drivers_config.prompt_driver = MockPromptDriver(max_attempts=1, use_native_tools=True) pipeline = Pipeline() pipeline.add_task(ToolkitTask(tools=[MockTool()])) diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 2090be39c..2dd68e24e 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -20,7 +20,7 @@ def test_run(self): def test_run_with_env(self, mock_config): pipeline = Pipeline() - mock_config.driver_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) + mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index 19493aa0f..ab448c7c1 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.events import EventListener, event_bus +from griptape.events import EventBus, EventListener from tests.mocks.mock_text_to_speech_driver import MockTextToSpeechDriver @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - event_bus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(handler=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index 7eb87036a..cc432dafb 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -1,45 +1,50 @@ from unittest.mock import Mock -from griptape.events import EventListener, event_bus +from griptape.events import EventBus, EventListener from tests.mocks.mock_event import MockEvent class TestEventBus: + def test_init(self): + from griptape.events.event_bus import _EventBus + + assert _EventBus() is _EventBus() + def test_add_event_listeners(self): - event_bus.add_event_listeners([EventListener(), EventListener()]) - assert len(event_bus.event_listeners) == 2 + EventBus.add_event_listeners([EventListener(), EventListener()]) + assert len(EventBus.event_listeners) == 2 def test_remove_event_listeners(self): listeners = [EventListener(), EventListener()] - event_bus.add_event_listeners(listeners) - event_bus.remove_event_listeners(listeners) - assert len(event_bus.event_listeners) == 0 + EventBus.add_event_listeners(listeners) + EventBus.remove_event_listeners(listeners) + assert len(EventBus.event_listeners) == 0 def test_add_event_listener(self): - event_bus.add_event_listener(EventListener()) - event_bus.add_event_listener(EventListener()) + EventBus.add_event_listener(EventListener()) + EventBus.add_event_listener(EventListener()) - assert len(event_bus.event_listeners) == 2 + assert len(EventBus.event_listeners) == 2 def test_remove_event_listener(self): listener = EventListener() - event_bus.add_event_listener(listener) - event_bus.remove_event_listener(listener) + EventBus.add_event_listener(listener) + EventBus.remove_event_listener(listener) - assert len(event_bus.event_listeners) == 0 + assert len(EventBus.event_listeners) == 0 def test_remove_unknown_event_listener(self): - event_bus.remove_event_listener(EventListener()) + EventBus.remove_event_listener(EventListener()) def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - event_bus.add_event_listeners([EventListener(handler=mock_handler)]) + EventBus.add_event_listeners([EventListener(handler=mock_handler)]) mock_event = MockEvent() # When - event_bus.publish_event(mock_event) + EventBus.publish_event(mock_event) # Then mock_handler.assert_called_once_with(mock_event) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 0078ebc34..a6d90d4fc 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -4,6 +4,7 @@ from griptape.events import ( CompletionChunkEvent, + EventBus, EventListener, FinishActionsSubtaskEvent, FinishPromptEvent, @@ -13,7 +14,6 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, - event_bus, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -26,7 +26,7 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - mock_config.driver_config.prompt_driver = MockPromptDriver(stream=True) + mock_config.drivers_config.prompt_driver = MockPromptDriver(stream=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) pipeline = Pipeline() @@ -39,7 +39,7 @@ def test_untyped_listeners(self, pipeline, mock_config): event_handler_1 = Mock() event_handler_2 = Mock() - event_bus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) + EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() @@ -60,7 +60,7 @@ def test_typed_listeners(self, pipeline, mock_config): finish_structure_run_event_handler = Mock() completion_chunk_handler = Mock() - event_bus.add_event_listeners( + EventBus.add_event_listeners( [ EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), @@ -90,25 +90,25 @@ def test_typed_listeners(self, pipeline, mock_config): completion_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): - event_bus.clear_event_listeners() + EventBus.clear_event_listeners() mock1 = Mock() mock2 = Mock() # duplicate event listeners will only get added once - event_listener_1 = event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_bus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + EventBus.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - event_listener_3 = event_bus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = event_bus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_3 = EventBus.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) + event_listener_4 = EventBus.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) - event_listener_5 = event_bus.add_event_listener(EventListener(mock2)) + event_listener_5 = EventBus.add_event_listener(EventListener(mock2)) - assert len(event_bus.event_listeners) == 4 + assert len(EventBus.event_listeners) == 4 - event_bus.remove_event_listener(event_listener_1) - event_bus.remove_event_listener(event_listener_3) - event_bus.remove_event_listener(event_listener_4) - event_bus.remove_event_listener(event_listener_5) - assert len(event_bus.event_listeners) == 0 + EventBus.remove_event_listener(event_listener_1) + EventBus.remove_event_listener(event_listener_3) + EventBus.remove_event_listener(event_listener_4) + EventBus.remove_event_listener(event_listener_5) + assert len(EventBus.event_listeners) == 0 def test_publish_event(self): mock_event_listener_driver = Mock() diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index b7137524e..3f9ac2344 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -97,7 +97,7 @@ def test_add_to_prompt_stack_autopruing_disabled(self): def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # All memory is pruned. - mock_config.driver_config.prompt_driver = MockPromptDriver( + mock_config.drivers_config.prompt_driver = MockPromptDriver( tokenizer=MockTokenizer(model="foo", max_input_tokens=0) ) agent = Agent() @@ -121,7 +121,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): assert len(prompt_stack.messages) == 3 # No memory is pruned. - mock_config.driver_config.prompt_driver = MockPromptDriver( + mock_config.drivers_config.prompt_driver = MockPromptDriver( tokenizer=MockTokenizer(model="foo", max_input_tokens=1000) ) agent = Agent() @@ -147,7 +147,7 @@ def test_add_to_prompt_stack_autopruning_enabled(self, mock_config): # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens # so that a single memory is pruned. - mock_config.driver_config.prompt_driver = MockPromptDriver( + mock_config.drivers_config.prompt_driver = MockPromptDriver( tokenizer=MockTokenizer(model="foo", max_input_tokens=160) ) agent = Agent() diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index d90f2f8ba..235363bbe 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -222,7 +222,7 @@ def test_task_memory_defaults(self, mock_config): storage = list(agent.task_memory.artifact_storages.values())[0] assert isinstance(storage, TextArtifactStorage) - assert storage.vector_store_driver.embedding_driver == mock_config.driver_config.embedding_driver + assert storage.vector_store_driver.embedding_driver == mock_config.drivers_config.embedding_driver def finished_tasks(self): task = PromptTask("test prompt") diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index e87a99ccb..94a13b938 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import TextArtifact -from griptape.events import event_bus +from griptape.events import EventBus from griptape.events.event_listener import EventListener from griptape.structures import Agent, Workflow from griptape.tasks import ActionsSubtask @@ -14,11 +14,11 @@ class TestBaseTask: @pytest.fixture() def task(self): - event_bus.add_event_listeners([EventListener(handler=Mock())]) + EventBus.add_event_listeners([EventListener(handler=Mock())]) agent = Agent( tools=[MockTool()], ) - event_bus.add_event_listeners([EventListener(handler=Mock())]) + EventBus.add_event_listeners([EventListener(handler=Mock())]) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -115,7 +115,7 @@ def test_children_property_no_structure(self, task): def test_execute_publish_events(self, task): task.execute() - assert event_bus.event_listeners[0].handler.call_count == 2 + assert EventBus.event_listeners[0].handler.call_count == 2 def test_add_parent(self, task): parent = MockTask("parent foobar", id="parent_foobar") diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index fe434a281..6ea9f5985 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -6,9 +6,9 @@ class TestStructureRunTask: def test_run(self, mock_config): - mock_config.driver_config.prompt_driver = MockPromptDriver(mock_output="agent mock output") + mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output="agent mock output") agent = Agent() - mock_config.driver_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") + mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index 9ba1df731..dbb76a943 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -168,7 +168,7 @@ class TestToolTask: def agent(self, mock_config): output_dict = {"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "foobar"}}} - mock_config.driver_config.prompt_driver = MockPromptDriver( + mock_config.drivers_config.prompt_driver = MockPromptDriver( mock_output=f"```python foo bar\n{json.dumps(output_dict)}" ) diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 6837fca78..6b238c399 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -171,7 +171,7 @@ def test_init(self): def test_run(self, mock_config): output = """Answer: done""" - mock_config.driver_config.prompt_driver.mock_output = output + mock_config.drivers_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) agent = Agent() @@ -186,7 +186,7 @@ def test_run(self, mock_config): def test_run_max_subtasks(self, mock_config): output = 'Actions: [{"tag": "foo", "name": "Tool1", "path": "test", "input": {"values": {"test": "value"}}}]' - mock_config.driver_config.prompt_driver.mock_output = output + mock_config.drivers_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() @@ -200,7 +200,7 @@ def test_run_max_subtasks(self, mock_config): def test_run_invalid_react_prompt(self, mock_config): output = """foo bar""" - mock_config.driver_config.prompt_driver.mock_output = output + mock_config.drivers_config.prompt_driver.mock_output = output task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) agent = Agent() diff --git a/tests/unit/utils/test_chat.py b/tests/unit/utils/test_chat.py index ff728718a..a8ffb1fff 100644 --- a/tests/unit/utils/test_chat.py +++ b/tests/unit/utils/test_chat.py @@ -1,7 +1,7 @@ import logging from unittest.mock import patch -from griptape.config import config +from griptape.configs import Defaults from griptape.memory.structure import ConversationMemory from griptape.structures import Agent from griptape.utils import Chat @@ -37,7 +37,7 @@ def test_chat_logger_level(self, mock_input): chat = Chat(agent) - logger = logging.getLogger(config.logging_config.logger_name) + logger = logging.getLogger(Defaults.logging_config.logger_name) logger.setLevel(logging.DEBUG) assert logger.getEffectiveLevel() == logging.DEBUG diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 7de57d85c..9fadf4e36 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -228,9 +228,9 @@ def prompt_driver_id_fn(cls, prompt_driver) -> str: return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}" def verify_structure_output(self, structure) -> dict: - from griptape.config import config + from griptape.configs import Defaults - config.driver_config.prompt_driver = AzureOpenAiChatPromptDriver( + Defaults.drivers_config.prompt_driver = AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], model="gpt-4o", azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"],