Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/naming #1078

Merged
merged 3 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/examples/src/multiple_agent_shared_memory_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
vector_path=MONGODB_VECTOR_PATH,
)

config.drivers = AzureOpenAiDriverConfig(
config.driver_config = AzureOpenAiDriverConfig(
azure_endpoint=AZURE_OPENAI_ENDPOINT_1,
vector_store=mongo_driver,
embedding=embedding_driver,
vector_store_driver=mongo_driver,
embedding_driver=embedding_driver,
)

loader = Agent(
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/src/talk_to_a_video_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from griptape.config.drivers import GoogleDriverConfig
from griptape.structures import Agent

config.drivers = GoogleDriverConfig()
config.driver_config = GoogleDriverConfig()

video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4")
while video_file.state.name == "PROCESSING":
Expand Down
12 changes: 6 additions & 6 deletions docs/griptape-framework/drivers/src/embedding_drivers_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from griptape.structures import Agent
from griptape.tools import PromptSummaryTool, WebScraperTool

config.drivers = DriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4o"),
embedding=VoyageAiEmbeddingDriver(),
config.driver_config = DriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
embedding_driver=VoyageAiEmbeddingDriver(),
)

config.drivers = DriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4o"),
embedding=VoyageAiEmbeddingDriver(),
config.driver_config = DriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
embedding_driver=VoyageAiEmbeddingDriver(),
)

agent = Agent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from griptape.rules import Rule
from griptape.structures import Agent

config.drivers = DriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7))
config.driver_config = DriverConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7))
event_bus.add_event_listeners(
[
EventListener(
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from griptape.config.drivers import OpenAiDriverConfig
from griptape.structures import Agent

config.drivers = OpenAiDriverConfig()
config.driver_config = OpenAiDriverConfig()

agent = Agent()
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from griptape.config.drivers import AzureOpenAiDriverConfig
from griptape.structures import Agent

config.drivers = AzureOpenAiDriverConfig(
config.driver_config = AzureOpenAiDriverConfig(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"]
)

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from griptape.config.drivers import AmazonBedrockDriverConfig
from griptape.structures import Agent

config.drivers = AmazonBedrockDriverConfig(
config.driver_config = AmazonBedrockDriverConfig(
session=boto3.Session(
region_name=os.environ["AWS_DEFAULT_REGION"],
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from griptape.config.drivers import GoogleDriverConfig
from griptape.structures import Agent

config.drivers = GoogleDriverConfig()
config.driver_config = GoogleDriverConfig()

agent = Agent()
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from griptape.config.drivers import AnthropicDriverConfig
from griptape.structures import Agent

config.drivers = AnthropicDriverConfig()
config.driver_config = AnthropicDriverConfig()

agent = Agent()
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
from griptape.config.drivers import CohereDriverConfig
from griptape.structures import Agent

config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])
config.driver_config = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])

agent = Agent()
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/src/config_7.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from griptape.drivers import AnthropicPromptDriver
from griptape.structures import Agent

config.drivers = DriverConfig(
prompt=AnthropicPromptDriver(
config.driver_config = DriverConfig(
prompt_driver=AnthropicPromptDriver(
model="claude-3-sonnet-20240229",
api_key=os.environ["ANTHROPIC_API_KEY"],
)
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_8.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
}
custom_config = AmazonBedrockDriverConfig.from_dict(dict_config)

config.drivers = custom_config
config.driver_config = custom_config

agent = Agent()
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/src/config_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from griptape.config.logging import TruncateLoggingFilter
from griptape.structures import Agent

config.drivers = OpenAiDriverConfig()
config.driver_config = OpenAiDriverConfig()

logger = logging.getLogger(config.logging.logger_name)
logger = logging.getLogger(config.logging_config.logger_name)
logger.setLevel(logging.ERROR)
logger.addFilter(TruncateLoggingFilter(max_log_length=100))

Expand Down
8 changes: 4 additions & 4 deletions docs/griptape-framework/structures/src/task_memory_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from griptape.structures import Agent
from griptape.tools import FileManagerTool, QueryTool, WebScraperTool

config.drivers = OpenAiDriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4"),
config.driver_config = OpenAiDriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
)

config.drivers = OpenAiDriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4"),
config.driver_config = OpenAiDriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
)

vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-tools/official-tools/src/rest_api_tool_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from griptape.tasks import ToolkitTask
from griptape.tools import RestApiTool

config.drivers = DriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1),
config.driver_config = DriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1),
)

posts_client = RestApiTool(
Expand Down
8 changes: 4 additions & 4 deletions griptape/config/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

@define(kw_only=True)
class BaseConfig(SerializableMixin, ABC):
_logging: Optional[LoggingConfig] = field(alias="logging")
_drivers: Optional[BaseDriverConfig] = field(alias="drivers")
_logging_config: Optional[LoggingConfig] = field(alias="logging")
_driver_config: Optional[BaseDriverConfig] = field(alias="drivers")

def reset(self) -> None:
self._logging = None
self._drivers = None
self._logging_config = None
self._driver_config = None
35 changes: 12 additions & 23 deletions griptape/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from attrs import define, field

from griptape.utils.decorators import lazy_property

from .base_config import BaseConfig
from .drivers.openai_driver_config import OpenAiDriverConfig
from .logging.logging_config import LoggingConfig
Expand All @@ -14,29 +16,16 @@

@define(kw_only=True)
class _Config(BaseConfig):
_logging: Optional[LoggingConfig] = field(default=None, alias="logging")
_drivers: Optional[BaseDriverConfig] = field(default=None, alias="drivers")

@property
def drivers(self) -> BaseDriverConfig:
"""Lazily instantiates the drivers configuration to avoid client errors like missing API key."""
if self._drivers is None:
self._drivers = OpenAiDriverConfig()
return self._drivers

@drivers.setter
def drivers(self, drivers: BaseDriverConfig) -> None:
self._drivers = drivers

@property
def logging(self) -> LoggingConfig:
if self._logging is None:
self._logging = LoggingConfig()
return self._logging

@logging.setter
def logging(self, logging: LoggingConfig) -> None:
self._logging = logging
_logging_config: Optional[LoggingConfig] = field(default=None, alias="logging")
_driver_config: Optional[BaseDriverConfig] = field(default=None, alias="drivers")

@lazy_property()
def driver_config(self) -> BaseDriverConfig:
return OpenAiDriverConfig()

@lazy_property()
def logging_config(self) -> LoggingConfig:
return LoggingConfig()


config = _Config()
10 changes: 5 additions & 5 deletions griptape/config/drivers/amazon_bedrock_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,31 @@ class AmazonBedrockDriverConfig(DriverConfig):
)

@lazy_property()
def prompt(self) -> AmazonBedrockPromptDriver:
def prompt_driver(self) -> AmazonBedrockPromptDriver:
return AmazonBedrockPromptDriver(session=self.session, model="anthropic.claude-3-5-sonnet-20240620-v1:0")

@lazy_property()
def embedding(self) -> AmazonBedrockTitanEmbeddingDriver:
def embedding_driver(self) -> AmazonBedrockTitanEmbeddingDriver:
return AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1")

@lazy_property()
def image_generation(self) -> AmazonBedrockImageGenerationDriver:
def image_generation_driver(self) -> AmazonBedrockImageGenerationDriver:
return AmazonBedrockImageGenerationDriver(
session=self.session,
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
)

@lazy_property()
def image_query(self) -> AmazonBedrockImageQueryDriver:
def image_query_driver(self) -> AmazonBedrockImageQueryDriver:
return AmazonBedrockImageQueryDriver(
session=self.session,
model="anthropic.claude-3-5-sonnet-20240620-v1:0",
image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
)

@lazy_property()
def vector_store(self) -> LocalVectorStoreDriver:
def vector_store_driver(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1")
)
8 changes: 4 additions & 4 deletions griptape/config/drivers/anthropic_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
@define
class AnthropicDriverConfig(DriverConfig):
@lazy_property()
def prompt(self) -> AnthropicPromptDriver:
def prompt_driver(self) -> AnthropicPromptDriver:
return AnthropicPromptDriver(model="claude-3-5-sonnet-20240620")

@lazy_property()
def embedding(self) -> VoyageAiEmbeddingDriver:
def embedding_driver(self) -> VoyageAiEmbeddingDriver:
return VoyageAiEmbeddingDriver(model="voyage-large-2")

@lazy_property()
def vector_store(self) -> LocalVectorStoreDriver:
def vector_store_driver(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2"))

@lazy_property()
def image_query(self) -> AnthropicImageQueryDriver:
def image_query_driver(self) -> AnthropicImageQueryDriver:
return AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620")
10 changes: 5 additions & 5 deletions griptape/config/drivers/azure_openai_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AzureOpenAiDriverConfig(DriverConfig):
api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})

@lazy_property()
def prompt(self) -> AzureOpenAiChatPromptDriver:
def prompt_driver(self) -> AzureOpenAiChatPromptDriver:
return AzureOpenAiChatPromptDriver(
model="gpt-4o",
azure_endpoint=self.azure_endpoint,
Expand All @@ -51,7 +51,7 @@ def prompt(self) -> AzureOpenAiChatPromptDriver:
)

@lazy_property()
def embedding(self) -> AzureOpenAiEmbeddingDriver:
def embedding_driver(self) -> AzureOpenAiEmbeddingDriver:
return AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
azure_endpoint=self.azure_endpoint,
Expand All @@ -61,7 +61,7 @@ def embedding(self) -> AzureOpenAiEmbeddingDriver:
)

@lazy_property()
def image_generation(self) -> AzureOpenAiImageGenerationDriver:
def image_generation_driver(self) -> AzureOpenAiImageGenerationDriver:
return AzureOpenAiImageGenerationDriver(
model="dall-e-2",
azure_endpoint=self.azure_endpoint,
Expand All @@ -72,7 +72,7 @@ def image_generation(self) -> AzureOpenAiImageGenerationDriver:
)

@lazy_property()
def image_query(self) -> AzureOpenAiImageQueryDriver:
def image_query_driver(self) -> AzureOpenAiImageQueryDriver:
return AzureOpenAiImageQueryDriver(
model="gpt-4o",
azure_endpoint=self.azure_endpoint,
Expand All @@ -82,7 +82,7 @@ def image_query(self) -> AzureOpenAiImageQueryDriver:
)

@lazy_property()
def vector_store(self) -> LocalVectorStoreDriver:
def vector_store_driver(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(
embedding_driver=AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
Expand Down
Loading
Loading