Skip to content

Commit

Permalink
Lazy load driver config fields
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 19, 2024
1 parent a931176 commit f5ae860
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 303 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `JsonArtifact` for handling de/seralization of values.
- `Chat.logger_level` for setting what the `Chat` utility sets the logger level to.
- `FuturesExecutorMixin` to DRY up and optimize concurrent code across multiple classes.
- `griptape.utils.decorators.lazy_property` for creating lazy properties.

### Changed
- **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`.
Expand Down
83 changes: 30 additions & 53 deletions griptape/config/drivers/amazon_bedrock_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
AmazonBedrockImageQueryDriver,
AmazonBedrockPromptDriver,
AmazonBedrockTitanEmbeddingDriver,
BaseEmbeddingDriver,
BaseImageGenerationDriver,
BaseImageQueryDriver,
BasePromptDriver,
BaseVectorStoreDriver,
BedrockClaudeImageQueryModelDriver,
BedrockTitanImageGenerationModelDriver,
LocalVectorStoreDriver,
)
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
Expand All @@ -33,51 +29,32 @@ class AmazonBedrockDriverConfig(DriverConfig):
metadata={"serializable": False},
)

prompt: BasePromptDriver = field(
default=Factory(
lambda self: AmazonBedrockPromptDriver(
session=self.session,
model="anthropic.claude-3-5-sonnet-20240620-v1:0",
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
embedding: BaseEmbeddingDriver = field(
default=Factory(
lambda self: AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1"),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
image_generation: BaseImageGenerationDriver = field(
default=Factory(
lambda self: AmazonBedrockImageGenerationDriver(
session=self.session,
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
image_query: BaseImageQueryDriver = field(
default=Factory(
lambda self: AmazonBedrockImageQueryDriver(
session=self.session,
model="anthropic.claude-3-5-sonnet-20240620-v1:0",
image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
vector_store: BaseVectorStoreDriver = field(
default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True),
kw_only=True,
metadata={"serializable": True},
)
@lazy_property
def prompt(self) -> AmazonBedrockPromptDriver:
return AmazonBedrockPromptDriver(session=self.session, model="anthropic.claude-3-5-sonnet-20240620-v1:0")

@lazy_property
def embedding(self) -> AmazonBedrockTitanEmbeddingDriver:
return AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1")

@lazy_property
def image_generation(self) -> AmazonBedrockImageGenerationDriver:
return AmazonBedrockImageGenerationDriver(
session=self.session,
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
)

@lazy_property
def image_query(self) -> AmazonBedrockImageQueryDriver:
return AmazonBedrockImageQueryDriver(
session=self.session,
model="anthropic.claude-3-5-sonnet-20240620-v1:0",
image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
)

@lazy_property
def vector_store(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1")
)
45 changes: 18 additions & 27 deletions griptape/config/drivers/anthropic_driver_config.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,30 @@
from attrs import Factory, define, field
from attrs import define

from griptape.config.drivers import DriverConfig
from griptape.drivers import (
AnthropicImageQueryDriver,
AnthropicPromptDriver,
BaseEmbeddingDriver,
BaseImageQueryDriver,
BasePromptDriver,
BaseVectorStoreDriver,
LocalVectorStoreDriver,
VoyageAiEmbeddingDriver,
)
from griptape.utils.decorators import lazy_property


# TODO move class fields to lazy properties
@define
class AnthropicDriverConfig(DriverConfig):
prompt: BasePromptDriver = field(
default=Factory(lambda: AnthropicPromptDriver(model="claude-3-5-sonnet-20240620")),
metadata={"serializable": True},
kw_only=True,
)
embedding: BaseEmbeddingDriver = field(
default=Factory(lambda: VoyageAiEmbeddingDriver(model="voyage-large-2")),
metadata={"serializable": True},
kw_only=True,
)
vector_store: BaseVectorStoreDriver = field(
default=Factory(
lambda: LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")),
),
kw_only=True,
metadata={"serializable": True},
)
image_query: BaseImageQueryDriver = field(
default=Factory(lambda: AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620")),
kw_only=True,
metadata={"serializable": True},
)
@lazy_property
def prompt(self) -> AnthropicPromptDriver:
return AnthropicPromptDriver(model="claude-3-5-sonnet-20240620")

@lazy_property
def embedding(self) -> VoyageAiEmbeddingDriver:
return VoyageAiEmbeddingDriver(model="voyage-large-2")

@lazy_property
def vector_store(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2"))

@lazy_property
def image_query(self) -> AnthropicImageQueryDriver:
return AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620")
113 changes: 50 additions & 63 deletions griptape/config/drivers/azure_openai_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,17 @@

from typing import Callable, Optional

from attrs import Factory, define, field
from attrs import define, field

from griptape.config.drivers import DriverConfig
from griptape.drivers import (
AzureOpenAiChatPromptDriver,
AzureOpenAiEmbeddingDriver,
AzureOpenAiImageGenerationDriver,
AzureOpenAiImageQueryDriver,
BaseEmbeddingDriver,
BaseImageGenerationDriver,
BaseImageQueryDriver,
BasePromptDriver,
BaseVectorStoreDriver,
LocalVectorStoreDriver,
)
from griptape.utils.decorators import lazy_property


@define
Expand All @@ -43,65 +39,56 @@ class AzureOpenAiDriverConfig(DriverConfig):
metadata={"serializable": False},
)
api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
prompt: BasePromptDriver = field(
default=Factory(
lambda self: AzureOpenAiChatPromptDriver(
model="gpt-4o",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
metadata={"serializable": True},
kw_only=True,
)
image_generation: BaseImageGenerationDriver = field(
default=Factory(
lambda self: AzureOpenAiImageGenerationDriver(
model="dall-e-2",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
image_size="512x512",
),
takes_self=True,
),
metadata={"serializable": True},
kw_only=True,
)
image_query: BaseImageQueryDriver = field(
default=Factory(
lambda self: AzureOpenAiImageQueryDriver(
model="gpt-4o",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
metadata={"serializable": True},
kw_only=True,
)
embedding: BaseEmbeddingDriver = field(
default=Factory(
lambda self: AzureOpenAiEmbeddingDriver(

@lazy_property
def prompt(self) -> AzureOpenAiChatPromptDriver:
return AzureOpenAiChatPromptDriver(
model="gpt-4o",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
)

@lazy_property
def embedding(self) -> AzureOpenAiEmbeddingDriver:
return AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
)

@lazy_property
def image_generation(self) -> AzureOpenAiImageGenerationDriver:
return AzureOpenAiImageGenerationDriver(
model="dall-e-2",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
image_size="512x512",
)

@lazy_property
def image_query(self) -> AzureOpenAiImageQueryDriver:
return AzureOpenAiImageQueryDriver(
model="gpt-4o",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
)

@lazy_property
def vector_store(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(
embedding_driver=AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
metadata={"serializable": True},
kw_only=True,
)
vector_store: BaseVectorStoreDriver = field(
default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding), takes_self=True),
metadata={"serializable": True},
kw_only=True,
)
)
)
69 changes: 56 additions & 13 deletions griptape/config/drivers/base_driver_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from abc import ABC
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.mixins import SerializableMixin
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from griptape.drivers import (
Expand All @@ -22,15 +23,57 @@

@define
class BaseDriverConfig(ABC, SerializableMixin):
prompt: BasePromptDriver = field(kw_only=True, metadata={"serializable": True})
image_generation: BaseImageGenerationDriver = field(kw_only=True, metadata={"serializable": True})
image_query: BaseImageQueryDriver = field(kw_only=True, metadata={"serializable": True})
embedding: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True})
vector_store: BaseVectorStoreDriver = field(kw_only=True, metadata={"serializable": True})
conversation_memory: Optional[BaseConversationMemoryDriver] = field(
default=None,
kw_only=True,
metadata={"serializable": True},
)
text_to_speech: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True})
audio_transcription: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True})
_prompt: BasePromptDriver = field(kw_only=True, default=None, metadata={"serializable": True}, alias="prompt")
_image_generation: BaseImageGenerationDriver = field(
kw_only=True, default=None, metadata={"serializable": True}, alias="image_generation"
)
_image_query: BaseImageQueryDriver = field(
kw_only=True, default=None, metadata={"serializable": True}, alias="image_query"
)
_embedding: BaseEmbeddingDriver = field(
kw_only=True, default=None, metadata={"serializable": True}, alias="embedding"
)
_vector_store: BaseVectorStoreDriver = field(
default=None, kw_only=True, metadata={"serializable": True}, alias="vector_store"
)
_conversation_memory: Optional[BaseConversationMemoryDriver] = field(
default=None, kw_only=True, metadata={"serializable": True}, alias="conversation_memory"
)
_text_to_speech: BaseTextToSpeechDriver = field(
default=None, kw_only=True, metadata={"serializable": True}, alias="text_to_speech"
)
_audio_transcription: BaseAudioTranscriptionDriver = field(
default=None, kw_only=True, metadata={"serializable": True}, alias="audio_transcription"
)

@lazy_property
@abstractmethod
def prompt(self) -> BasePromptDriver: ...

@lazy_property
@abstractmethod
def image_generation(self) -> BaseImageGenerationDriver: ...

@lazy_property
@abstractmethod
def image_query(self) -> BaseImageQueryDriver: ...

@lazy_property
@abstractmethod
def embedding(self) -> BaseEmbeddingDriver: ...

@lazy_property
@abstractmethod
def vector_store(self) -> BaseVectorStoreDriver: ...

@lazy_property
@abstractmethod
def conversation_memory(self) -> Optional[BaseConversationMemoryDriver]: ...

@lazy_property
@abstractmethod
def text_to_speech(self) -> BaseTextToSpeechDriver: ...

@lazy_property
@abstractmethod
def audio_transcription(self) -> BaseAudioTranscriptionDriver: ...
Loading

0 comments on commit f5ae860

Please sign in to comment.