diff --git a/CHANGELOG.md b/CHANGELOG.md index d3d8ddcf43..820852e71f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Changed +- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver` to `client`. +- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockTitanEmbeddingDriver` to `client`. +- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockImageGenerationDriver` to `client`. +- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockImageQueryDriver` to `client`. +- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockPromptDriver` to `client`. +- **BREAKING**: Renamed parameter `sagemaker_client` on `AmazonSageMakerJumpstartEmbeddingDriver` to `client`. +- **BREAKING**: Renamed parameter `sagemaker_client` on `AmazonSageMakerJumpstartPromptDriver` to `client`. +- **BREAKING**: Renamed parameter `sqs_client` on `AmazonSqsEventListenerDriver` to `client`. +- **BREAKING**: Renamed parameter `iotdata_client` on `AwsIotCoreEventListenerDriver` to `client`. +- **BREAKING**: Renamed parameter `s3_client` on `AmazonS3FileManagerDriver` to `client`. +- **BREAKING**: Renamed parameter `s3_client` on `AwsS3Tool` to `client`. +- **BREAKING**: Renamed parameter `pusher_client` on `PusherEventListenerDriver` to `client`. +- **BREAKING**: Renamed parameter `mq` on `MarqoVectorStoreDriver` to `client`. +- **BREAKING**: Renamed parameter `model_client` on `GooglePromptDriver` to `client`. +- **BREAKING**: Renamed parameter `model_client` on `GoogleTokenizer` to `client`. +- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `text_generation_pipeline`. +- **BREAKING**: Renamed parameter `engine` on `PgVectorVectorStoreDriver` to `sqlalchemy_engine`. +- Several places where API clients are initialized are now lazy loaded. + + ## [0.32.0] - 2024-09-17 ### Added diff --git a/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py b/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py index 312fa83187..f81031897a 100644 --- a/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py @@ -4,10 +4,11 @@ from typing import Optional import openai -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import AudioArtifact, TextArtifact from griptape.drivers import BaseAudioTranscriptionDriver +from griptape.utils.decorators import lazy_property @define @@ -17,12 +18,11 @@ class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: additional_params = {} diff --git a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py index 4e4f4aa31d..1799be705f 100644 --- a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py @@ -8,6 +8,7 @@ from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 @@ -26,7 +27,7 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver): `search_query` when querying your vector DB to find relevant documents. session: Optionally provide custom `boto3.Session`. tokenizer: Optionally provide custom `BedrockCohereTokenizer`. - bedrock_client: Optionally provide custom `bedrock-runtime` client. + client: Optionally provide custom `bedrock-runtime` client. """ DEFAULT_MODEL = "cohere.embed-english-v3" @@ -38,15 +39,16 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver): default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True, ) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("bedrock-runtime") def try_embed_chunk(self, chunk: str) -> list[float]: payload = {"input_type": self.input_type, "texts": [chunk]} - response = self.bedrock_client.invoke_model( + response = self.client.invoke_model( body=json.dumps(payload), modelId=self.model, accept="*/*", diff --git a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py index 5900d7d86b..a52e0840da 100644 --- a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py @@ -8,6 +8,7 @@ from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 @@ -23,7 +24,7 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver): model: Embedding model name. Defaults to DEFAULT_MODEL. tokenizer: Optionally provide custom `BedrockTitanTokenizer`. session: Optionally provide custom `boto3.Session`. - bedrock_client: Optionally provide custom `bedrock-runtime` client. + client: Optionally provide custom `bedrock-runtime` client. """ DEFAULT_MODEL = "amazon.titan-embed-text-v1" @@ -34,15 +35,16 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver): default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True, ) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("bedrock-runtime") def try_embed_chunk(self, chunk: str) -> list[float]: payload = {"inputText": chunk} - response = self.bedrock_client.invoke_model( + response = self.client.invoke_model( body=json.dumps(payload), modelId=self.model, accept="application/json", diff --git a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py index c4feb8a1df..5032b0caf8 100644 --- a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import BaseEmbeddingDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 @@ -15,18 +16,19 @@ @define class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - sagemaker_client: Any = field( - default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), - kw_only=True, - ) endpoint: str = field(kw_only=True, metadata={"serializable": True}) custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True}) inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("sagemaker-runtime") def try_embed_chunk(self, chunk: str) -> list[float]: payload = {"text_inputs": chunk, "mode": "embedding"} - endpoint_response = self.sagemaker_client.invoke_endpoint( + endpoint_response = self.client.invoke_endpoint( EndpointName=self.endpoint, ContentType="application/json", Body=json.dumps(payload).encode("utf-8"), diff --git a/griptape/drivers/embedding/azure_openai_embedding_driver.py b/griptape/drivers/embedding/azure_openai_embedding_driver.py index c1e601aef4..89fb620e0a 100644 --- a/griptape/drivers/embedding/azure_openai_embedding_driver.py +++ b/griptape/drivers/embedding/azure_openai_embedding_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import OpenAiEmbeddingDriver from griptape.tokenizers import OpenAiTokenizer +from griptape.utils.decorators import lazy_property @define @@ -40,17 +41,16 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver): default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True, ) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) diff --git a/griptape/drivers/embedding/cohere_embedding_driver.py b/griptape/drivers/embedding/cohere_embedding_driver.py index 365dc972ed..34526ed753 100644 --- a/griptape/drivers/embedding/cohere_embedding_driver.py +++ b/griptape/drivers/embedding/cohere_embedding_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import CohereTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from cohere import Client @@ -27,16 +28,16 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver): DEFAULT_MODEL = "models/embedding-001" api_key: str = field(kw_only=True, metadata={"serializable": False}) - client: Client = field( - default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True), - kw_only=True, - ) + input_type: str = field(kw_only=True, metadata={"serializable": True}) + _client: Client = field(default=None, kw_only=True, metadata={"serializable": False}) tokenizer: CohereTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True, ) - input_type: str = field(kw_only=True, metadata={"serializable": True}) + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("cohere").Client(self.api_key) def try_embed_chunk(self, chunk: str) -> list[float]: result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type) diff --git a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py index c1be2ec96e..069a9044b6 100644 --- a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py +++ b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseEmbeddingDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from huggingface_hub import InferenceClient @@ -22,16 +23,14 @@ class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver): """ api_token: str = field(kw_only=True, metadata={"serializable": True}) - client: InferenceClient = field( - default=Factory( - lambda self: import_optional_dependency("huggingface_hub").InferenceClient( - model=self.model, - token=self.api_token, - ), - takes_self=True, - ), - kw_only=True, - ) + _client: InferenceClient = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> InferenceClient: + return import_optional_dependency("huggingface_hub").InferenceClient( + model=self.model, + token=self.api_token, + ) def try_embed_chunk(self, chunk: str) -> list[float]: response = self.client.feature_extraction(chunk) diff --git a/griptape/drivers/embedding/ollama_embedding_driver.py b/griptape/drivers/embedding/ollama_embedding_driver.py index c5c30d5af0..f6409ec4b3 100644 --- a/griptape/drivers/embedding/ollama_embedding_driver.py +++ b/griptape/drivers/embedding/ollama_embedding_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseEmbeddingDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from ollama import Client @@ -23,10 +24,11 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver): model: str = field(kw_only=True, metadata={"serializable": True}) host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: Client = field( - default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True), - kw_only=True, - ) + _client: Client = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("ollama").Client(host=self.host) def try_embed_chunk(self, chunk: str) -> list[float]: return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"]) diff --git a/griptape/drivers/embedding/openai_embedding_driver.py b/griptape/drivers/embedding/openai_embedding_driver.py index 0995fba68b..b0b799790e 100644 --- a/griptape/drivers/embedding/openai_embedding_driver.py +++ b/griptape/drivers/embedding/openai_embedding_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import OpenAiTokenizer +from griptape.utils.decorators import lazy_property @define @@ -33,16 +34,15 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) tokenizer: OpenAiTokenizer = field( default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True, ) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) def try_embed_chunk(self, chunk: str) -> list[float]: # Address a performance issue in older ada models diff --git a/griptape/drivers/embedding/voyageai_embedding_driver.py b/griptape/drivers/embedding/voyageai_embedding_driver.py index c5e418ed18..e08629afa1 100644 --- a/griptape/drivers/embedding/voyageai_embedding_driver.py +++ b/griptape/drivers/embedding/voyageai_embedding_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import VoyageAiTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property @define @@ -25,17 +26,16 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver): model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) - client: Any = field( - default=Factory( - lambda self: import_optional_dependency("voyageai").Client(api_key=self.api_key), - takes_self=True, - ), - ) tokenizer: VoyageAiTokenizer = field( default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True), kw_only=True, ) input_type: str = field(default="document", kw_only=True, metadata={"serializable": True}) + _client: Any = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return import_optional_dependency("voyageai").Client(api_key=self.api_key) def try_embed_chunk(self, chunk: str) -> list[float]: return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0] diff --git a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py index 4c632cb01e..e2019db5c5 100644 --- a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py +++ b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py @@ -7,6 +7,7 @@ from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 @@ -16,10 +17,14 @@ class AmazonSqsEventListenerDriver(BaseEventListenerDriver): queue_url: str = field(kw_only=True) session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True)) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("sqs") def try_publish_event_payload(self, event_payload: dict) -> None: - self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload)) + self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload)) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: entries = [ @@ -27,4 +32,4 @@ def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> No for event_payload in event_payload_batch ] - self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries) + self.client.send_message_batch(QueueUrl=self.queue_url, Entries=entries) diff --git a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py index 3b014aed4c..25ccd5bed8 100644 --- a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py +++ b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py @@ -7,6 +7,7 @@ from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 @@ -17,10 +18,14 @@ class AwsIotCoreEventListenerDriver(BaseEventListenerDriver): iot_endpoint: str = field(kw_only=True) topic: str = field(kw_only=True) session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - iotdata_client: Any = field(default=Factory(lambda self: self.session.client("iot-data"), takes_self=True)) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("iot-data") def try_publish_event_payload(self, event_payload: dict) -> None: - self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload)) + self.client.publish(topic=self.topic, payload=json.dumps(event_payload)) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: - self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch)) + self.client.publish(topic=self.topic, payload=json.dumps(event_payload_batch)) diff --git a/griptape/drivers/event_listener/pusher_event_listener_driver.py b/griptape/drivers/event_listener/pusher_event_listener_driver.py index ce9a4fb34d..8d8119ed47 100644 --- a/griptape/drivers/event_listener/pusher_event_listener_driver.py +++ b/griptape/drivers/event_listener/pusher_event_listener_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseEventListenerDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from pusher import Pusher @@ -13,25 +14,24 @@ @define class PusherEventListenerDriver(BaseEventListenerDriver): - app_id: str = field(kw_only=True) - key: str = field(kw_only=True) - secret: str = field(kw_only=True) - cluster: str = field(kw_only=True) - channel: str = field(kw_only=True) - event_name: str = field(kw_only=True) - pusher_client: Pusher = field( - default=Factory( - lambda self: import_optional_dependency("pusher").Pusher( - app_id=self.app_id, - key=self.key, - secret=self.secret, - cluster=self.cluster, - ssl=True, - ), - takes_self=True, - ), - kw_only=True, - ) + app_id: str = field(kw_only=True, metadata={"serializable": True}) + key: str = field(kw_only=True, metadata={"serializable": True}) + secret: str = field(kw_only=True, metadata={"serializable": False}) + cluster: str = field(kw_only=True, metadata={"serializable": True}) + channel: str = field(kw_only=True, metadata={"serializable": True}) + event_name: str = field(kw_only=True, metadata={"serializable": True}) + ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + _client: Pusher = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Pusher: + return import_optional_dependency("pusher").Pusher( + app_id=self.app_id, + key=self.key, + secret=self.secret, + cluster=self.cluster, + ssl=self.ssl, + ) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: data = [ @@ -39,7 +39,7 @@ def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> No for event_payload in event_payload_batch ] - self.pusher_client.trigger_batch(data) + self.client.trigger_batch(data) def try_publish_event_payload(self, event_payload: dict) -> None: - self.pusher_client.trigger(channels=self.channel, event_name=self.event_name, data=event_payload) + self.client.trigger(channels=self.channel, event_name=self.event_name, data=event_payload) diff --git a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py index 20e432c0b0..7e27d0e4de 100644 --- a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py +++ b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py @@ -4,6 +4,7 @@ from attrs import Attribute, Factory, define, field +from griptape.utils.decorators import lazy_property from griptape.utils.import_utils import import_optional_dependency from .base_file_manager_driver import BaseFileManagerDriver @@ -21,13 +22,17 @@ class AmazonS3FileManagerDriver(BaseFileManagerDriver): bucket: The name of the S3 bucket. workdir: The absolute working directory (must start with "/"). List, load, and save operations will be performed relative to this directory. - s3_client: The S3 client to use for S3 operations. + client: The S3 client to use for S3 operations. """ session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) bucket: str = field(kw_only=True) workdir: str = field(default="/", kw_only=True) - s3_client: Any = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("s3") @workdir.validator # pyright: ignore[reportAttributeAccessIssue] def validate_workdir(self, _: Attribute, workdir: str) -> None: @@ -51,7 +56,7 @@ def try_load_file(self, path: str) -> bytes: raise IsADirectoryError try: - response = self.s3_client.get_object(Bucket=self.bucket, Key=full_key) + response = self.client.get_object(Bucket=self.bucket, Key=full_key) return response["Body"].read() except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] in {"NoSuchKey", "404"}: @@ -62,7 +67,7 @@ def try_save_file(self, path: str, value: bytes) -> None: full_key = self._to_full_key(path) if self._is_a_directory(full_key): raise IsADirectoryError - self.s3_client.put_object(Bucket=self.bucket, Key=full_key, Body=value) + self.client.put_object(Bucket=self.bucket, Key=full_key, Body=value) def _to_full_key(self, path: str) -> str: path = path.lstrip("/") @@ -90,7 +95,7 @@ def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]: if max_items is not None: pagination_config["MaxItems"] = max_items - paginator = self.s3_client.get_paginator("list_objects_v2") + paginator = self.client.get_paginator("list_objects_v2") pages = paginator.paginate( Bucket=self.bucket, Prefix=full_key, @@ -116,7 +121,7 @@ def _is_a_directory(self, full_key: str) -> bool: return True try: - self.s3_client.head_object(Bucket=self.bucket, Key=full_key) + self.client.head_object(Bucket=self.bucket, Key=full_key) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] in {"NoSuchKey", "404"}: return len(self._list_files_and_dirs(full_key, max_items=1)) > 0 diff --git a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py index 4db302f6f6..aa82d64e9b 100644 --- a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py +++ b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py @@ -8,6 +8,7 @@ from griptape.artifacts import ImageArtifact from griptape.drivers import BaseMultiModelImageGenerationDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 @@ -20,19 +21,21 @@ class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver): Attributes: model: Bedrock model ID. session: boto3 session. - bedrock_client: Bedrock runtime client. + client: Bedrock runtime client. image_width: Width of output images. Defaults to 512 and must be a multiple of 64. image_height: Height of output images. Defaults to 512 and must be a multiple of 64. seed: Optionally provide a consistent seed to generation requests, increasing consistency in output. """ session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client(service_name="bedrock-runtime"), takes_self=True), - ) image_width: int = field(default=512, kw_only=True, metadata={"serializable": True}) image_height: int = field(default=512, kw_only=True, metadata={"serializable": True}) seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("bedrock-runtime") def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: request = self.image_generation_model_driver.text_to_image_request_parameters( @@ -127,7 +130,7 @@ def try_image_outpainting( ) def _make_request(self, request: dict) -> bytes: - response = self.bedrock_client.invoke_model( + response = self.client.invoke_model( body=json.dumps(request), modelId=self.model, accept="application/json", diff --git a/griptape/drivers/image_generation/azure_openai_image_generation_driver.py b/griptape/drivers/image_generation/azure_openai_image_generation_driver.py index 85facda4c1..f5d936ce5c 100644 --- a/griptape/drivers/image_generation/azure_openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/azure_openai_image_generation_driver.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.drivers import OpenAiImageGenerationDriver +from griptape.utils.decorators import lazy_property @define @@ -34,17 +35,16 @@ class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver): metadata={"serializable": False}, ) api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) diff --git a/griptape/drivers/image_generation/openai_image_generation_driver.py b/griptape/drivers/image_generation/openai_image_generation_driver.py index bf77ac300b..ec8129e892 100644 --- a/griptape/drivers/image_generation/openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/openai_image_generation_driver.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING, Literal, Optional, Union, cast import openai -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import ImageArtifact from griptape.drivers import BaseImageGenerationDriver +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from openai.types.images_response import ImagesResponse @@ -38,12 +39,6 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) style: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) quality: Union[Literal["standard"], Literal["hd"]] = field( default="standard", @@ -58,6 +53,11 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver): Literal["1792x1024"], ] = field(default="1024x1024", kw_only=True, metadata={"serializable": True}) response_format: Literal["b64_json"] = field(default="b64_json", kw_only=True, metadata={"serializable": True}) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: prompt = ", ".join(prompts) diff --git a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py index 46406d9726..044a41c818 100644 --- a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py +++ b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import BaseMultiModelImageQueryDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 @@ -17,15 +18,16 @@ @define class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("bedrock-runtime") def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens) - response = self.bedrock_client.invoke_model( + response = self.client.invoke_model( modelId=self.model, contentType="application/json", accept="application/json", diff --git a/griptape/drivers/image_query/anthropic_image_query_driver.py b/griptape/drivers/image_query/anthropic_image_query_driver.py index a50685724f..808e2ca930 100644 --- a/griptape/drivers/image_query/anthropic_image_query_driver.py +++ b/griptape/drivers/image_query/anthropic_image_query_driver.py @@ -2,11 +2,12 @@ from typing import Any, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseImageQueryDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property @define @@ -21,13 +22,11 @@ class AnthropicImageQueryDriver(BaseImageQueryDriver): api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) model: str = field(kw_only=True, metadata={"serializable": True}) - client: Any = field( - default=Factory( - lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - ) + _client: Any = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: if self.max_tokens is None: diff --git a/griptape/drivers/image_query/azure_openai_image_query_driver.py b/griptape/drivers/image_query/azure_openai_image_query_driver.py index 04492e4719..c3389b138a 100644 --- a/griptape/drivers/image_query/azure_openai_image_query_driver.py +++ b/griptape/drivers/image_query/azure_openai_image_query_driver.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.drivers.image_query.openai_image_query_driver import OpenAiImageQueryDriver +from griptape.utils.decorators import lazy_property @define @@ -34,17 +35,16 @@ class AzureOpenAiImageQueryDriver(OpenAiImageQueryDriver): metadata={"serializable": False}, ) api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) diff --git a/griptape/drivers/image_query/openai_image_query_driver.py b/griptape/drivers/image_query/openai_image_query_driver.py index 6399efa95d..f0ef9e148f 100644 --- a/griptape/drivers/image_query/openai_image_query_driver.py +++ b/griptape/drivers/image_query/openai_image_query_driver.py @@ -3,7 +3,7 @@ from typing import Literal, Optional import openai -from attrs import Factory, define, field +from attrs import define, field from openai.types.chat import ( ChatCompletionContentPartImageParam, ChatCompletionContentPartParam, @@ -13,6 +13,7 @@ from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers.image_query.base_image_query_driver import BaseImageQueryDriver +from griptape.utils.decorators import lazy_property @define @@ -24,12 +25,11 @@ class OpenAiImageQueryDriver(BaseImageQueryDriver): api_key: Optional[str] = field(default=None, kw_only=True) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: message_parts: list[ChatCompletionContentPartParam] = [ diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index bc339f6184..be34d2a8cf 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -31,6 +31,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -44,10 +45,6 @@ @define class AmazonBedrockPromptDriver(BasePromptDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True) tokenizer: BaseTokenizer = field( default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), @@ -55,10 +52,15 @@ class AmazonBedrockPromptDriver(BasePromptDriver): ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("bedrock-runtime") @observable def try_run(self, prompt_stack: PromptStack) -> Message: - response = self.bedrock_client.converse(**self._base_params(prompt_stack)) + response = self.client.converse(**self._base_params(prompt_stack)) usage = response["usage"] output_message = response["output"]["message"] @@ -71,7 +73,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: @observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: - response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack)) + response = self.client.converse_stream(**self._base_params(prompt_stack)) stream = response.get("stream") if stream is not None: diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index d7a2f5b0b2..2dcf55307c 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -10,6 +10,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -22,10 +23,6 @@ @define class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - sagemaker_client: Any = field( - default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), - kw_only=True, - ) endpoint: str = field(kw_only=True, metadata={"serializable": True}) custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True}) inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) @@ -38,6 +35,11 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): ), kw_only=True, ) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("sagemaker-runtime") @stream.validator # pyright: ignore[reportAttributeAccessIssue] def validate_stream(self, _: Attribute, stream: bool) -> None: # noqa: FBT001 @@ -51,7 +53,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: "parameters": {**self._base_params(prompt_stack)}, } - response = self.sagemaker_client.invoke_endpoint( + response = self.client.invoke_endpoint( EndpointName=self.endpoint, ContentType="application/json", Body=json.dumps(payload), diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index ae50bc59e5..8c944b2cce 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -32,6 +32,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -54,13 +55,6 @@ class AnthropicPromptDriver(BasePromptDriver): api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) model: str = field(kw_only=True, metadata={"serializable": True}) - client: Client = field( - default=Factory( - lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - ) tokenizer: BaseTokenizer = field( default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True, @@ -70,6 +64,11 @@ class AnthropicPromptDriver(BasePromptDriver): tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) + _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py index b08b51b69c..02cfb69bb8 100644 --- a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.drivers import OpenAiChatPromptDriver +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from griptape.common import PromptStack @@ -37,20 +38,19 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver): metadata={"serializable": False}, ) api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) def _base_params(self, prompt_stack: PromptStack) -> dict: params = super()._base_params(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index ff1a8b4824..b31c78ea35 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -23,6 +23,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, CohereTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -45,14 +46,16 @@ class CoherePromptDriver(BasePromptDriver): api_key: str = field(metadata={"serializable": False}) model: str = field(metadata={"serializable": True}) - client: Client = field( - default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True), - ) + force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), ) - force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("cohere").Client(self.api_key) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 6b18f60416..4afdad5c6c 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -26,6 +26,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, GoogleTokenizer from griptape.utils import import_optional_dependency, remove_key_in_dict_recursively +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -44,17 +45,13 @@ class GooglePromptDriver(BasePromptDriver): Attributes: api_key: Google API key. model: Google model name. - model_client: Custom `GenerativeModel` client. + client: Custom `GenerativeModel` client. top_p: Optional value for top_p. top_k: Optional value for top_k. """ api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) model: str = field(kw_only=True, metadata={"serializable": True}) - model_client: GenerativeModel = field( - default=Factory(lambda self: self._default_model_client(), takes_self=True), - kw_only=True, - ) tokenizer: BaseTokenizer = field( default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True, @@ -63,11 +60,19 @@ class GooglePromptDriver(BasePromptDriver): top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) + _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> GenerativeModel: + genai = import_optional_dependency("google.generativeai") + genai.configure(api_key=self.api_key) + + return genai.GenerativeModel(self.model) @observable def try_run(self, prompt_stack: PromptStack) -> Message: messages = self.__to_google_messages(prompt_stack) - response: GenerateContentResponse = self.model_client.generate_content( + response: GenerateContentResponse = self.client.generate_content( messages, **self._base_params(prompt_stack), ) @@ -86,7 +91,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: @observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: messages = self.__to_google_messages(prompt_stack) - response: GenerateContentResponse = self.model_client.generate_content( + response: GenerateContentResponse = self.client.generate_content( messages, **self._base_params(prompt_stack), stream=True, @@ -119,7 +124,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = prompt_stack.system_messages if system_messages: - self.model_client._system_instruction = types.ContentDict( + self.client._system_instruction = types.ContentDict( role="system", parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages], ) @@ -146,12 +151,6 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: ), } - def _default_model_client(self) -> GenerativeModel: - genai = import_optional_dependency("google.generativeai") - genai.configure(api_key=self.api_key) - - return genai.GenerativeModel(self.model) - def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: types = import_optional_dependency("google.generativeai.types") diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 657b5747c8..68267f755b 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -8,6 +8,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -32,16 +33,6 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) - client: InferenceClient = field( - default=Factory( - lambda self: import_optional_dependency("huggingface_hub").InferenceClient( - model=self.model, - token=self.api_token, - ), - takes_self=True, - ), - kw_only=True, - ) tokenizer: HuggingFaceTokenizer = field( default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), @@ -49,6 +40,14 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): ), kw_only=True, ) + _client: InferenceClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> InferenceClient: + return import_optional_dependency("huggingface_hub").InferenceClient( + model=self.model, + token=self.api_token, + ) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 128167f527..273db870b2 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -9,6 +9,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -35,23 +36,24 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ), kw_only=True, ) - pipe: TextGenerationPipeline = field( - default=Factory( - lambda self: import_optional_dependency("transformers").pipeline( - "text-generation", - model=self.model, - max_new_tokens=self.max_tokens, - tokenizer=self.tokenizer.tokenizer, - ), - takes_self=True, - ), + _text_generation_pipeline: TextGenerationPipeline = field( + default=None, kw_only=True, alias="text_generation_pipeline", metadata={"serializable": False} ) + @lazy_property() + def text_generation_pipeline(self) -> TextGenerationPipeline: + return import_optional_dependency("transformers").pipeline( + "text-generation", + model=self.model, + max_new_tokens=self.max_tokens, + tokenizer=self.tokenizer.tokenizer, + ) + @observable def try_run(self, prompt_stack: PromptStack) -> Message: messages = self._prompt_stack_to_messages(prompt_stack) - result = self.pipe( + result = self.text_generation_pipeline( messages, max_new_tokens=self.max_tokens, temperature=self.temperature, diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 70d4ce89af..5f9e32e2f6 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -22,6 +22,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import SimpleTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from ollama import Client @@ -40,10 +41,6 @@ class OllamaPromptDriver(BasePromptDriver): model: str = field(kw_only=True, metadata={"serializable": True}) host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: Client = field( - default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True), - kw_only=True, - ) tokenizer: BaseTokenizer = field( default=Factory( lambda self: SimpleTokenizer( @@ -67,6 +64,11 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("ollama").Client(host=self.host) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 987bdc2add..9c67b53c9a 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -25,6 +25,7 @@ ) from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -55,12 +56,6 @@ class OpenAiChatPromptDriver(BasePromptDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) model: str = field(kw_only=True, metadata={"serializable": True}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), @@ -88,6 +83,15 @@ class OpenAiChatPromptDriver(BasePromptDriver): ), kw_only=True, ) + _client: openai.OpenAI = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI( + base_url=self.base_url, + api_key=self.api_key, + organization=self.organization, + ) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py b/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py index 562a1d6376..0faf592f66 100644 --- a/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.drivers import OpenAiTextToSpeechDriver +from griptape.utils.decorators import lazy_property @define @@ -35,17 +36,16 @@ class AzureOpenAiTextToSpeechDriver(OpenAiTextToSpeechDriver): metadata={"serializable": False}, ) api_version: str = field(default="2024-07-01-preview", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) diff --git a/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py index f4be581621..ef6352cea6 100644 --- a/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py @@ -1,27 +1,28 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact from griptape.drivers import BaseTextToSpeechDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property + +if TYPE_CHECKING: + from elevenlabs.client import ElevenLabs @define class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver): api_key: str = field(kw_only=True, metadata={"serializable": True}) - client: Any = field( - default=Factory( - lambda self: import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, - ) voice: str = field(kw_only=True, metadata={"serializable": True}) output_format: str = field(default="mp3_44100_128", kw_only=True, metadata={"serializable": True}) + _client: ElevenLabs = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> ElevenLabs: + return import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key) def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: audio = self.client.generate( diff --git a/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py b/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py index 543ef1ec79..2522f0b3af 100644 --- a/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py @@ -3,10 +3,11 @@ from typing import Literal, Optional import openai -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact from griptape.drivers import BaseTextToSpeechDriver +from griptape.utils.decorators import lazy_property @define @@ -23,12 +24,15 @@ class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) + _client: openai.OpenAI = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI( + api_key=self.api_key, + base_url=self.base_url, + organization=self.organization, + ) def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: response = self.client.audio.speech.create( diff --git a/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py b/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py index b1d8819587..465dfa4764 100644 --- a/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: from boto3 import Session - from opensearchpy import OpenSearch @define @@ -36,19 +35,6 @@ class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver): ), ) - client: OpenSearch = field( - default=Factory( - lambda self: import_optional_dependency("opensearchpy").OpenSearch( - hosts=[{"host": self.host, "port": self.port}], - http_auth=self.http_auth, - use_ssl=self.use_ssl, - verify_certs=self.verify_certs, - connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection, - ), - takes_self=True, - ), - ) - def upsert_vector( self, vector: list[float], diff --git a/griptape/drivers/vector/astradb_vector_store_driver.py b/griptape/drivers/vector/astradb_vector_store_driver.py index 029fa382d9..85832be00e 100644 --- a/griptape/drivers/vector/astradb_vector_store_driver.py +++ b/griptape/drivers/vector/astradb_vector_store_driver.py @@ -6,10 +6,11 @@ from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: - from astrapy import Collection - from astrapy.authentication import TokenProvider + import astrapy + import astrapy.authentication @define @@ -26,33 +27,35 @@ class AstraDbVectorStoreDriver(BaseVectorStoreDriver): It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values. astra_db_namespace: optional specification of the namespace (in the Astra database) for the data. *Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store. + caller_name: the name of the caller for the Astra DB client. Defaults to "griptape". + client: an instance of `astrapy.DataAPIClient` for the Astra DB. + collection: an instance of `astrapy.Collection` for the Astra DB. """ api_endpoint: str = field(kw_only=True, metadata={"serializable": True}) - token: Optional[str | TokenProvider] = field(kw_only=True, default=None, metadata={"serializable": False}) + token: Optional[str | astrapy.authentication.TokenProvider] = field( + kw_only=True, default=None, metadata={"serializable": False} + ) collection_name: str = field(kw_only=True, metadata={"serializable": True}) environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - - collection: Collection = field(init=False) - - def __attrs_post_init__(self) -> None: - astrapy = import_optional_dependency("astrapy") - self.collection = ( - astrapy.DataAPIClient( - caller_name="griptape", - environment=self.environment, - ) - .get_database( - self.api_endpoint, - token=self.token, - namespace=self.astra_db_namespace, - ) - .get_collection( - name=self.collection_name, - ) + caller_name: str = field(default="griptape", kw_only=True, metadata={"serializable": False}) + _client: astrapy.DataAPIClient = field(default=None, kw_only=True, metadata={"serializable": False}) + _collection: astrapy.Collection = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> astrapy.DataAPIClient: + return import_optional_dependency("astrapy").DataAPIClient( + caller_name=self.caller_name, + environment=self.environment, ) + @lazy_property() + def collection(self) -> astrapy.Collection: + return self.client.get_database( + self.api_endpoint, token=self.token, namespace=self.astra_db_namespace + ).get_collection(self.collection_name) + def delete_vector(self, vector_id: str) -> None: """Delete a vector from Astra DB store. diff --git a/griptape/drivers/vector/marqo_vector_store_driver.py b/griptape/drivers/vector/marqo_vector_store_driver.py index caab118b87..f520f65aad 100644 --- a/griptape/drivers/vector/marqo_vector_store_driver.py +++ b/griptape/drivers/vector/marqo_vector_store_driver.py @@ -2,11 +2,12 @@ from typing import TYPE_CHECKING, Any, NoReturn, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape import utils from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import marqo @@ -21,20 +22,18 @@ class MarqoVectorStoreDriver(BaseVectorStoreDriver): Attributes: api_key: The API key for the Marqo API. url: The URL to the Marqo API. - mq: An optional Marqo client. Defaults to a new client with the given URL and API key. + client: An optional Marqo client. Defaults to a new client with the given URL and API key. index: The name of the index to use. """ api_key: str = field(kw_only=True, metadata={"serializable": True}) url: str = field(kw_only=True, metadata={"serializable": True}) - mq: Optional[marqo.Client] = field( - default=Factory( - lambda self: import_optional_dependency("marqo").Client(self.url, api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - ) index: str = field(kw_only=True, metadata={"serializable": True}) + _client: marqo.Client = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> marqo.Client: + return import_optional_dependency("marqo").Client(self.url, api_key=self.api_key) def upsert_text( self, @@ -65,7 +64,7 @@ def upsert_text( if namespace: doc["namespace"] = namespace - response = self.mq.index(self.index).add_documents([doc], tensor_fields=["Description"]) + response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description"]) if isinstance(response, dict) and "items" in response and response["items"]: return response["items"][0]["_id"] else: @@ -102,7 +101,7 @@ def upsert_text_artifact( "namespace": namespace, } - response = self.mq.index(self.index).add_documents([doc], tensor_fields=["Description", "artifact"]) + response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description", "artifact"]) if isinstance(response, dict) and "items" in response and response["items"]: return response["items"][0]["_id"] else: @@ -118,7 +117,7 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti Returns: The loaded Entry if found, otherwise None. """ - result = self.mq.index(self.index).get_document(document_id=vector_id, expose_facets=True) + result = self.client.index(self.index).get_document(document_id=vector_id, expose_facets=True) if result and "_tensor_facets" in result and len(result["_tensor_facets"]) > 0: return BaseVectorStoreDriver.Entry( @@ -141,15 +140,15 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto filter_string = f"namespace:{namespace}" if namespace else None if filter_string is not None: - results = self.mq.index(self.index).search("", limit=10000, filter_string=filter_string) + results = self.client.index(self.index).search("", limit=10000, filter_string=filter_string) else: - results = self.mq.index(self.index).search("", limit=10000) + results = self.client.index(self.index).search("", limit=10000) # get all _id's from search results ids = [r["_id"] for r in results["hits"]] # get documents corresponding to the ids - documents = self.mq.index(self.index).get_documents(document_ids=ids, expose_facets=True) + documents = self.client.index(self.index).get_documents(document_ids=ids, expose_facets=True) # for each document, if it's found, create an Entry object entries = [] @@ -195,11 +194,12 @@ def query( "filter_string": f"namespace:{namespace}" if namespace else None, } | kwargs - results = self.mq.index(self.index).search(query, **params) + results = self.client.index(self.index).search(query, **params) if include_vectors: results["hits"] = [ - {**r, **self.mq.index(self.index).get_document(r["_id"], expose_facets=True)} for r in results["hits"] + {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)} + for r in results["hits"] ] return [ @@ -218,7 +218,7 @@ def delete_index(self, name: str) -> dict[str, Any]: Args: name: The name of the index to delete. """ - return self.mq.delete_index(name) + return self.client.delete_index(name) def get_indexes(self) -> list[str]: """Get a list of all indexes in the Marqo client. @@ -226,7 +226,7 @@ def get_indexes(self) -> list[str]: Returns: The list of all indexes. """ - return [index["index"] for index in self.mq.get_indexes()["results"]] + return [index["index"] for index in self.client.get_indexes()["results"]] def upsert_vector( self, diff --git a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py index bc3f1e22f6..d61b2603cd 100644 --- a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py +++ b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from pymongo import MongoClient @@ -37,12 +38,11 @@ class MongoDbAtlasVectorStoreDriver(BaseVectorStoreDriver): kw_only=True, metadata={"serializable": True}, ) # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#fields - client: MongoClient = field( - default=Factory( - lambda self: import_optional_dependency("pymongo").MongoClient(self.connection_string), - takes_self=True, - ), - ) + _client: MongoClient = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> MongoClient: + return import_optional_dependency("pymongo").MongoClient(self.connection_string) def get_collection(self) -> Collection: """Returns the MongoDB Collection instance for the specified database and collection name.""" diff --git a/griptape/drivers/vector/opensearch_vector_store_driver.py b/griptape/drivers/vector/opensearch_vector_store_driver.py index cf944116a1..b33fd34ba3 100644 --- a/griptape/drivers/vector/opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/opensearch_vector_store_driver.py @@ -3,11 +3,12 @@ import logging from typing import TYPE_CHECKING, NoReturn, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape import utils from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from opensearchpy import OpenSearch @@ -32,19 +33,19 @@ class OpenSearchVectorStoreDriver(BaseVectorStoreDriver): use_ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True}) verify_certs: bool = field(default=True, kw_only=True, metadata={"serializable": True}) index_name: str = field(kw_only=True, metadata={"serializable": True}) - - client: OpenSearch = field( - default=Factory( - lambda self: import_optional_dependency("opensearchpy").OpenSearch( - hosts=[{"host": self.host, "port": self.port}], - http_auth=self.http_auth, - use_ssl=self.use_ssl, - verify_certs=self.verify_certs, - connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection, - ), - takes_self=True, - ), - ) + _client: OpenSearch = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> OpenSearch: + opensearchpy = import_optional_dependency("opensearchpy") + + return opensearchpy.OpenSearch( + hosts=[{"host": self.host, "port": self.port}], + http_auth=self.http_auth, + use_ssl=self.use_ssl, + verify_certs=self.verify_certs, + connection_class=opensearchpy.RequestsHttpConnection, + ) def upsert_vector( self, diff --git a/griptape/drivers/vector/pgvector_vector_store_driver.py b/griptape/drivers/vector/pgvector_vector_store_driver.py index 30f437c7e4..1b2aa471db 100644 --- a/griptape/drivers/vector/pgvector_vector_store_driver.py +++ b/griptape/drivers/vector/pgvector_vector_store_driver.py @@ -9,9 +9,10 @@ from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: - from sqlalchemy.engine import Engine + import sqlalchemy @define @@ -27,14 +28,14 @@ class PgVectorVectorStoreDriver(BaseVectorStoreDriver): connection_string: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) - engine: Optional[Engine] = field(default=None, kw_only=True) table_name: str = field(kw_only=True, metadata={"serializable": True}) _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True)) + _sqlalchemy_engine: sqlalchemy.Engine = field(default=None, kw_only=True, metadata={"serializable": False}) @connection_string.validator # pyright: ignore[reportAttributeAccessIssue] def validate_connection_string(self, _: Attribute, connection_string: Optional[str]) -> None: # If an engine is provided, the connection string is not used. - if self.engine is not None: + if self._sqlalchemy_engine is not None: return # If an engine is not provided, a connection string is required. @@ -44,22 +45,11 @@ def validate_connection_string(self, _: Attribute, connection_string: Optional[s if not connection_string.startswith("postgresql://"): raise ValueError("The connection string must describe a Postgres database connection") - @engine.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_engine(self, _: Attribute, engine: Optional[Engine]) -> None: - # If a connection string is provided, an engine does not need to be provided. - if self.connection_string is not None: - return - - # If a connection string is not provided, an engine is required. - if engine is None: - raise ValueError("An engine or connection string is required") - - def __attrs_post_init__(self) -> None: - if self.engine is None: - if self.connection_string is None: - raise ValueError("An engine or connection string is required") - sqlalchemy = import_optional_dependency("sqlalchemy") - self.engine = sqlalchemy.create_engine(self.connection_string, **self.create_engine_params) + @lazy_property() + def sqlalchemy_engine(self) -> sqlalchemy.Engine: + return import_optional_dependency("sqlalchemy").create_engine( + self.connection_string, **self.create_engine_params + ) def setup( self, @@ -72,15 +62,15 @@ def setup( sqlalchemy_sql = import_optional_dependency("sqlalchemy.sql") if install_uuid_extension: - with self.engine.begin() as conn: + with self.sqlalchemy_engine.begin() as conn: conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')) if install_vector_extension: - with self.engine.begin() as conn: + with self.sqlalchemy_engine.begin() as conn: conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "vector";')) if create_schema: - self._model.metadata.create_all(self.engine) + self._model.metadata.create_all(self.sqlalchemy_engine) def upsert_vector( self, @@ -94,7 +84,7 @@ def upsert_vector( """Inserts or updates a vector in the collection.""" sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm") - with sqlalchemy_orm.Session(self.engine) as session: + with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session: obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs) obj = session.merge(obj) @@ -106,7 +96,7 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Base """Retrieves a specific vector entry from the collection based on its identifier and optional namespace.""" sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm") - with sqlalchemy_orm.Session(self.engine) as session: + with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session: result = session.get(self._model, vector_id) return BaseVectorStoreDriver.Entry( @@ -120,7 +110,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto """Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace.""" sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm") - with sqlalchemy_orm.Session(self.engine) as session: + with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session: query = session.query(self._model) if namespace: query = query.filter_by(namespace=namespace) @@ -161,7 +151,7 @@ def query( op = distance_metrics[distance_metric] - with sqlalchemy_orm.Session(self.engine) as session: + with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session: vector = self.embedding_driver.embed_string(query) # The query should return both the vector and the distance metric score. diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index a3a132ab36..500b090f58 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -6,6 +6,7 @@ from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency, str_to_hash +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import pinecone @@ -17,16 +18,20 @@ class PineconeVectorStoreDriver(BaseVectorStoreDriver): index_name: str = field(kw_only=True, metadata={"serializable": True}) environment: str = field(kw_only=True, metadata={"serializable": True}) project_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - index: pinecone.Index = field(init=False) + _client: pinecone.Pinecone = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + _index: pinecone.Index = field(default=None, kw_only=True, alias="index", metadata={"serializable": False}) - def __attrs_post_init__(self) -> None: - pinecone = import_optional_dependency("pinecone").Pinecone( + @lazy_property() + def client(self) -> pinecone.Pinecone: + return import_optional_dependency("pinecone").Pinecone( api_key=self.api_key, environment=self.environment, project_name=self.project_name, ) - self.index = pinecone.Index(self.index_name) + @lazy_property() + def index(self) -> pinecone.Index: + return self.client.get_index(self.index_name) def upsert_vector( self, diff --git a/griptape/drivers/vector/qdrant_vector_store_driver.py b/griptape/drivers/vector/qdrant_vector_store_driver.py index 154e54af73..ea2232afac 100644 --- a/griptape/drivers/vector/qdrant_vector_store_driver.py +++ b/griptape/drivers/vector/qdrant_vector_store_driver.py @@ -2,12 +2,17 @@ import logging import uuid -from typing import Optional +from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property + +if TYPE_CHECKING: + from qdrant_client import QdrantClient + DEFAULT_DISTANCE = "Cosine" CONTENT_PAYLOAD_KEY = "data" @@ -56,9 +61,11 @@ class QdrantVectorStoreDriver(BaseVectorStoreDriver): collection_name: str = field(kw_only=True, metadata={"serializable": True}) vector_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) content_payload_key: str = field(default=CONTENT_PAYLOAD_KEY, kw_only=True, metadata={"serializable": True}) + _client: QdrantClient = field(default=None, kw_only=True, metadata={"serializable": False}) - def __attrs_post_init__(self) -> None: - self.client = import_optional_dependency("qdrant_client").QdrantClient( + @lazy_property() + def client(self) -> QdrantClient: + return import_optional_dependency("qdrant_client").QdrantClient( location=self.location, url=self.url, host=self.host, diff --git a/griptape/drivers/vector/redis_vector_store_driver.py b/griptape/drivers/vector/redis_vector_store_driver.py index 0abf2c9854..d8827a3d00 100644 --- a/griptape/drivers/vector/redis_vector_store_driver.py +++ b/griptape/drivers/vector/redis_vector_store_driver.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING, NoReturn, Optional import numpy as np -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency, str_to_hash +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from redis import Redis @@ -33,19 +34,17 @@ class RedisVectorStoreDriver(BaseVectorStoreDriver): db: int = field(kw_only=True, default=0, metadata={"serializable": True}) password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) index: str = field(kw_only=True, metadata={"serializable": True}) - - client: Redis = field( - default=Factory( - lambda self: import_optional_dependency("redis").Redis( - host=self.host, - port=self.port, - db=self.db, - password=self.password, - decode_responses=False, - ), - takes_self=True, - ), - ) + _client: Redis = field(default=None, kw_only=True, metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Redis: + return import_optional_dependency("redis").Redis( + host=self.host, + port=self.port, + db=self.db, + password=self.password, + decode_responses=False, + ) def upsert_vector( self, diff --git a/griptape/drivers/web_search/duck_duck_go_web_search_driver.py b/griptape/drivers/web_search/duck_duck_go_web_search_driver.py index b67e81f35c..96891c2d44 100644 --- a/griptape/drivers/web_search/duck_duck_go_web_search_driver.py +++ b/griptape/drivers/web_search/duck_duck_go_web_search_driver.py @@ -3,11 +3,12 @@ import json from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers import BaseWebSearchDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from duckduckgo_search import DDGS @@ -15,7 +16,11 @@ @define class DuckDuckGoWebSearchDriver(BaseWebSearchDriver): - client: DDGS = field(default=Factory(lambda: import_optional_dependency("duckduckgo_search").DDGS()), kw_only=True) + _client: DDGS = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> DDGS: + return import_optional_dependency("duckduckgo_search").DDGS() def search(self, query: str, **kwargs) -> ListArtifact: try: diff --git a/griptape/tokenizers/google_tokenizer.py b/griptape/tokenizers/google_tokenizer.py index 87020bd967..144c09d754 100644 --- a/griptape/tokenizers/google_tokenizer.py +++ b/griptape/tokenizers/google_tokenizer.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.tokenizers import BaseTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from google.generativeai import GenerativeModel @@ -17,16 +18,14 @@ class GoogleTokenizer(BaseTokenizer): MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"gemini": 2048} api_key: str = field(kw_only=True, metadata={"serializable": True}) - model_client: GenerativeModel = field( - default=Factory(lambda self: self._default_model_client(), takes_self=True), - kw_only=True, - ) + _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - def count_tokens(self, text: str) -> int: - return self.model_client.count_tokens(text).total_tokens - - def _default_model_client(self) -> GenerativeModel: + @lazy_property() + def client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") genai.configure(api_key=self.api_key) return genai.GenerativeModel(self.model) + + def count_tokens(self, text: str) -> int: + return self.client.count_tokens(text).total_tokens diff --git a/griptape/tools/aws_s3/tool.py b/griptape/tools/aws_s3/tool.py index 24d091d711..f83703500c 100644 --- a/griptape/tools/aws_s3/tool.py +++ b/griptape/tools/aws_s3/tool.py @@ -3,12 +3,12 @@ import io from typing import TYPE_CHECKING, Any -from attrs import Factory, define, field +from attrs import define, field from schema import Literal, Schema from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.tools import BaseAwsTool -from griptape.utils.decorators import activity +from griptape.utils.decorators import activity, lazy_property if TYPE_CHECKING: from mypy_boto3_s3 import Client @@ -16,7 +16,11 @@ @define class AwsS3Tool(BaseAwsTool): - s3_client: Client = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True) + _client: Client = field(default=None, kw_only=True) + + @lazy_property() + def client(self) -> Client: + return self.session.client("s3") @activity( config={ @@ -33,7 +37,7 @@ class AwsS3Tool(BaseAwsTool): ) def get_bucket_acl(self, params: dict) -> TextArtifact | ErrorArtifact: try: - acl = self.s3_client.get_bucket_acl(Bucket=params["values"]["bucket_name"]) + acl = self.client.get_bucket_acl(Bucket=params["values"]["bucket_name"]) return TextArtifact(acl) except Exception as e: return ErrorArtifact(f"error getting bucket acl: {e}") @@ -48,7 +52,7 @@ def get_bucket_acl(self, params: dict) -> TextArtifact | ErrorArtifact: ) def get_bucket_policy(self, params: dict) -> TextArtifact | ErrorArtifact: try: - policy = self.s3_client.get_bucket_policy(Bucket=params["values"]["bucket_name"]) + policy = self.client.get_bucket_policy(Bucket=params["values"]["bucket_name"]) return TextArtifact(policy) except Exception as e: return ErrorArtifact(f"error getting bucket policy: {e}") @@ -66,7 +70,7 @@ def get_bucket_policy(self, params: dict) -> TextArtifact | ErrorArtifact: ) def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact: try: - acl = self.s3_client.get_object_acl( + acl = self.client.get_object_acl( Bucket=params["values"]["bucket_name"], Key=params["values"]["object_key"], ) @@ -77,7 +81,7 @@ def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact: @activity(config={"description": "Can be used to list all AWS S3 buckets."}) def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact: try: - buckets = self.s3_client.list_buckets() + buckets = self.client.list_buckets() return ListArtifact([TextArtifact(str(b)) for b in buckets["Buckets"]]) except Exception as e: @@ -91,7 +95,7 @@ def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact: ) def list_objects(self, params: dict) -> ListArtifact | ErrorArtifact: try: - objects = self.s3_client.list_objects_v2(Bucket=params["values"]["bucket_name"]) + objects = self.client.list_objects_v2(Bucket=params["values"]["bucket_name"]) if "Contents" not in objects: return ErrorArtifact("no objects found in the bucket") @@ -192,7 +196,7 @@ def download_objects(self, params: dict) -> ListArtifact | ErrorArtifact: artifacts = [] for object_info in objects: try: - obj = self.s3_client.get_object(Bucket=object_info["bucket_name"], Key=object_info["object_key"]) + obj = self.client.get_object(Bucket=object_info["bucket_name"], Key=object_info["object_key"]) content = obj["Body"].read() artifacts.append(BlobArtifact(content, name=object_info["object_key"])) @@ -203,9 +207,9 @@ def download_objects(self, params: dict) -> ListArtifact | ErrorArtifact: return ListArtifact(artifacts) def _upload_object(self, bucket_name: str, object_name: str, value: Any) -> None: - self.s3_client.create_bucket(Bucket=bucket_name) + self.client.create_bucket(Bucket=bucket_name) - self.s3_client.upload_fileobj( + self.client.upload_fileobj( Fileobj=io.BytesIO(value.encode() if isinstance(value, str) else value), Bucket=bucket_name, Key=object_name, diff --git a/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py b/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py index 50856c0da1..4168f762d6 100644 --- a/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py @@ -9,11 +9,11 @@ class TestPusherEventListenerDriver: @pytest.fixture(autouse=True) def mock_post(self, mocker): - mock_pusher_client = mocker.patch("pusher.Pusher") - mock_pusher_client.return_value.trigger.return_value = Mock() - mock_pusher_client.return_value.trigger_batch.return_value = Mock() + mock_client = mocker.patch("pusher.Pusher") + mock_client.return_value.trigger.return_value = Mock() + mock_client.return_value.trigger_batch.return_value = Mock() - return mock_pusher_client + return mock_client @pytest.fixture() def driver(self): @@ -33,12 +33,12 @@ def test_try_publish_event_payload(self, driver): data = MockEvent().to_dict() driver.try_publish_event_payload(data) - driver.pusher_client.trigger.assert_called_with(channels="test-channel", event_name="test-event", data=data) + driver.client.trigger.assert_called_with(channels="test-channel", event_name="test-event", data=data) def test_try_publish_event_payload_batch(self, driver): data = [MockEvent().to_dict() for _ in range(3)] driver.try_publish_event_payload_batch(data) - driver.pusher_client.trigger_batch.assert_called_with( + driver.client.trigger_batch.assert_called_with( [{"channel": "test-channel", "name": "test-event", "data": data[i]} for i in range(3)] ) diff --git a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py index 05e669b661..dae6f695b8 100644 --- a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py @@ -8,13 +8,13 @@ class TestAmazonBedrockImageGenerationDriver: @pytest.fixture() - def bedrock_client(self): + def client(self): return Mock() @pytest.fixture() - def session(self, bedrock_client): + def session(self, client): session = Mock() - session.client.return_value = bedrock_client + session.client.return_value = client return session @@ -40,7 +40,7 @@ def test_init_requires_image_generation_model_driver(self, session): AmazonBedrockImageGenerationDriver(session=session, model="stability.stable-diffusion-xl-v1") # pyright: ignore[reportCallIssue] def test_try_text_to_image(self, driver): - driver.bedrock_client.invoke_model.return_value = { + driver.client.invoke_model.return_value = { "body": io.BytesIO( b"""{ "artifacts": [ diff --git a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py index 9493ab23de..66b23d0c38 100644 --- a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py @@ -9,13 +9,13 @@ class TestAmazonBedrockImageQueryDriver: @pytest.fixture() - def bedrock_client(self, mocker): + def client(self, mocker): return Mock() @pytest.fixture() - def session(self, bedrock_client): + def session(self, client): session = Mock() - session.client.return_value = bedrock_client + session.client.return_value = client return session @@ -35,7 +35,7 @@ def test_init(self, image_query_driver): assert image_query_driver def test_try_query(self, image_query_driver): - image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"""{"content": []}""")} + image_query_driver.client.invoke_model.return_value = {"body": io.BytesIO(b"""{"content": []}""")} text_artifact = image_query_driver.try_query( "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")] @@ -44,7 +44,7 @@ def test_try_query(self, image_query_driver): assert text_artifact.value == "content" def test_try_query_no_body(self, image_query_driver): - image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"")} + image_query_driver.client.invoke_model.return_value = {"body": io.BytesIO(b"")} with pytest.raises(ValueError): image_query_driver.try_query( diff --git a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py index 254a2b3a16..2d824e1c29 100644 --- a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py @@ -86,7 +86,7 @@ def driver(self, mock_marqo): api_key="foobar", url="http://localhost:8000", index="test", - mq=mock_marqo, + client=mock_marqo, embedding_driver=MockEmbeddingDriver(), ) diff --git a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py index 72833f5a36..b46ce64512 100644 --- a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py @@ -36,7 +36,9 @@ def test_initialize_requires_engine_or_connection_string(self, embedding_driver) def test_initialize_accepts_engine(self, embedding_driver): engine: Any = create_engine(self.connection_string) - PgVectorVectorStoreDriver(embedding_driver=embedding_driver, engine=engine, table_name=self.table_name) + PgVectorVectorStoreDriver( + embedding_driver=embedding_driver, sqlalchemy_engine=engine, table_name=self.table_name + ) def test_initialize_accepts_connection_string(self, embedding_driver): PgVectorVectorStoreDriver( @@ -48,7 +50,7 @@ def test_upsert_vector(self, mock_session, mock_engine): mock_session.merge.return_value = Mock(id=test_id) driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) returned_id = driver.upsert_vector([1.0, 2.0, 3.0]) @@ -65,7 +67,7 @@ def test_load_entry(self, mock_session, mock_engine): mock_session.get.return_value = Mock(id=test_id, vector=test_vec, namespace=test_namespace, meta=test_meta) driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) entry = driver.load_entry(vector_id=test_id) @@ -88,7 +90,7 @@ def test_load_entries(self, mock_session, mock_engine): mock_session.query.return_value = mock_query driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) entries = driver.load_entries() @@ -104,7 +106,7 @@ def test_load_entries(self, mock_session, mock_engine): def test_query_invalid_distance_metric(self, mock_engine): driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) with pytest.raises(ValueError): @@ -122,7 +124,7 @@ def test_query(self, mock_session, mock_engine): mock_session.query().order_by().limit().all.return_value = test_result driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) result = driver.query("some query", include_vectors=True) @@ -147,7 +149,7 @@ def test_query_filter(self, mock_session, mock_engine): mock_session.query().order_by().filter_by().limit().all.return_value = test_result driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), engine=mock_engine, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) result = driver.query("some query", include_vectors=True, filter={"namespace": test_namespaces[0]}) diff --git a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py index 0726a0c7ea..a963fb370e 100644 --- a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py +++ b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py @@ -7,7 +7,7 @@ class TestPineconeVectorStorageDriver: @pytest.fixture(autouse=True) - def _mock_pinecone(self, mocker): + def mock_index(self, mocker): # Create a fake response fake_query_response = { "matches": [{"id": "foo", "values": [0, 1, 0], "score": 42, "metadata": {"foo": "bar"}}], @@ -15,14 +15,21 @@ def _mock_pinecone(self, mocker): } mock_client = mocker.patch("pinecone.Pinecone") - mock_client().Index().upsert.return_value = None - mock_client().Index().query.return_value = fake_query_response - mock_client().create_index.return_value = None + mock_index = mock_client().Index() + mock_index.upsert.return_value = None + mock_index.query.return_value = fake_query_response + mock_index.create_index.return_value = None + + return mock_index @pytest.fixture() - def driver(self): + def driver(self, mock_index): return PineconeVectorStoreDriver( - api_key="foobar", index_name="test", environment="test", embedding_driver=MockEmbeddingDriver() + api_key="foobar", + index_name="test", + environment="test", + embedding_driver=MockEmbeddingDriver(), + index=mock_index, ) def test_upsert_text_artifact(self, driver): diff --git a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py index ffb3599533..3c14f23962 100644 --- a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py @@ -37,30 +37,6 @@ def driver(self, embedding_driver, mocker): embedding_driver=embedding_driver, ) - def test_attrs_post_init(self, driver): - with patch("griptape.drivers.vector.qdrant_vector_store_driver.import_optional_dependency") as mock_import: - mock_qdrant_client = MagicMock() - mock_import.return_value.QdrantClient.return_value = mock_qdrant_client - - driver.__attrs_post_init__() - - mock_import.assert_called_once_with("qdrant_client") - mock_import.return_value.QdrantClient.assert_called_once_with( - location=driver.location, - url=driver.url, - host=driver.host, - path=driver.path, - port=driver.port, - prefer_grpc=driver.prefer_grpc, - grpc_port=driver.grpc_port, - api_key=driver.api_key, - https=driver.https, - prefix=driver.prefix, - force_disable_check_same_thread=driver.force_disable_check_same_thread, - timeout=driver.timeout, - ) - assert driver.client == mock_qdrant_client - def test_delete_vector(self, driver): vector_id = "test_vector_id"