diff --git a/CHANGELOG.md b/CHANGELOG.md index d3d8ddcf4..289ec72d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +## Added +- Parameter `pipeline_task` on `HuggingFacePipelinePromptDriver` for creating different types of `Pipeline`s. + +### Changed +- **BREAKING**: Renamed parameters on several classes to `client`: + - `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver`. + - `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver`. + - `bedrock_client` on `AmazonBedrockTitanEmbeddingDriver`. + - `bedrock_client` on `AmazonBedrockImageGenerationDriver`. + - `bedrock_client` on `AmazonBedrockImageQueryDriver`. + - `bedrock_client` on `AmazonBedrockPromptDriver`. + - `sagemaker_client` on `AmazonSageMakerJumpstartEmbeddingDriver`. + - `sagemaker_client` on `AmazonSageMakerJumpstartPromptDriver`. + - `sqs_client` on `AmazonSqsEventListenerDriver`. + - `iotdata_client` on `AwsIotCoreEventListenerDriver`. + - `s3_client` on `AmazonS3FileManagerDriver`. + - `s3_client` on `AwsS3Tool`. + - `iam_client` on `AwsIamTool`. + - `pusher_client` on `PusherEventListenerDriver`. + - `mq` on `MarqoVectorStoreDriver`. + - `model_client` on `GooglePromptDriver`. + - `model_client` on `GoogleTokenizer`. +- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `pipeline`. +- Several places where API clients are initialized are now lazy loaded. + + ## [0.32.0] - 2024-09-17 ### Added diff --git a/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py b/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py index 312fa8318..f81031897 100644 --- a/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py +++ b/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py @@ -4,10 +4,11 @@ from typing import Optional import openai -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import AudioArtifact, TextArtifact from griptape.drivers import BaseAudioTranscriptionDriver +from griptape.utils.decorators import lazy_property @define @@ -17,12 +18,11 @@ class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: additional_params = {} diff --git a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py index 4e4f4aa31..c1b2069c8 100644 --- a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py @@ -1,16 +1,18 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from attrs import Factory, define, field from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_bedrock import BedrockClient from griptape.tokenizers.base_tokenizer import BaseTokenizer @@ -26,7 +28,7 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver): `search_query` when querying your vector DB to find relevant documents. session: Optionally provide custom `boto3.Session`. tokenizer: Optionally provide custom `BedrockCohereTokenizer`. - bedrock_client: Optionally provide custom `bedrock-runtime` client. + client: Optionally provide custom `bedrock-runtime` client. """ DEFAULT_MODEL = "cohere.embed-english-v3" @@ -38,15 +40,16 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver): default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True, ) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) + _client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> BedrockClient: + return self.session.client("bedrock-runtime") def try_embed_chunk(self, chunk: str) -> list[float]: payload = {"input_type": self.input_type, "texts": [chunk]} - response = self.bedrock_client.invoke_model( + response = self.client.invoke_model( body=json.dumps(payload), modelId=self.model, accept="*/*", diff --git a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py index 5900d7d86..a17af9aee 100644 --- a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py @@ -1,16 +1,18 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from attrs import Factory, define, field from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_bedrock import BedrockClient from griptape.tokenizers.base_tokenizer import BaseTokenizer @@ -23,7 +25,7 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver): model: Embedding model name. Defaults to DEFAULT_MODEL. tokenizer: Optionally provide custom `BedrockTitanTokenizer`. session: Optionally provide custom `boto3.Session`. - bedrock_client: Optionally provide custom `bedrock-runtime` client. + client: Optionally provide custom `bedrock-runtime` client. """ DEFAULT_MODEL = "amazon.titan-embed-text-v1" @@ -34,15 +36,16 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver): default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True, ) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) + _client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> BedrockClient: + return self.session.client("bedrock-runtime") def try_embed_chunk(self, chunk: str) -> list[float]: payload = {"inputText": chunk} - response = self.bedrock_client.invoke_model( + response = self.client.invoke_model( body=json.dumps(payload), modelId=self.model, accept="application/json", diff --git a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py index c4feb8a1d..c047236de 100644 --- a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py @@ -1,32 +1,35 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from attrs import Factory, define, field from griptape.drivers import BaseEmbeddingDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_sagemaker import SageMakerClient @define class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - sagemaker_client: Any = field( - default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), - kw_only=True, - ) endpoint: str = field(kw_only=True, metadata={"serializable": True}) custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True}) inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + _client: SageMakerClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> SageMakerClient: + return self.session.client("sagemaker-runtime") def try_embed_chunk(self, chunk: str) -> list[float]: payload = {"text_inputs": chunk, "mode": "embedding"} - endpoint_response = self.sagemaker_client.invoke_endpoint( + endpoint_response = self.client.invoke_endpoint( EndpointName=self.endpoint, ContentType="application/json", Body=json.dumps(payload).encode("utf-8"), diff --git a/griptape/drivers/embedding/azure_openai_embedding_driver.py b/griptape/drivers/embedding/azure_openai_embedding_driver.py index c1e601aef..366a91460 100644 --- a/griptape/drivers/embedding/azure_openai_embedding_driver.py +++ b/griptape/drivers/embedding/azure_openai_embedding_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import OpenAiEmbeddingDriver from griptape.tokenizers import OpenAiTokenizer +from griptape.utils.decorators import lazy_property @define @@ -40,17 +41,16 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver): default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True, ) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) diff --git a/griptape/drivers/embedding/cohere_embedding_driver.py b/griptape/drivers/embedding/cohere_embedding_driver.py index 365dc972e..42e89ff70 100644 --- a/griptape/drivers/embedding/cohere_embedding_driver.py +++ b/griptape/drivers/embedding/cohere_embedding_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import CohereTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from cohere import Client @@ -27,16 +28,16 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver): DEFAULT_MODEL = "models/embedding-001" api_key: str = field(kw_only=True, metadata={"serializable": False}) - client: Client = field( - default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True), - kw_only=True, - ) + input_type: str = field(kw_only=True, metadata={"serializable": True}) + _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: CohereTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True, ) - input_type: str = field(kw_only=True, metadata={"serializable": True}) + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("cohere").Client(self.api_key) def try_embed_chunk(self, chunk: str) -> list[float]: result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type) diff --git a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py index c1be2ec96..573bfc379 100644 --- a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py +++ b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseEmbeddingDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from huggingface_hub import InferenceClient @@ -22,16 +23,14 @@ class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver): """ api_token: str = field(kw_only=True, metadata={"serializable": True}) - client: InferenceClient = field( - default=Factory( - lambda self: import_optional_dependency("huggingface_hub").InferenceClient( - model=self.model, - token=self.api_token, - ), - takes_self=True, - ), - kw_only=True, - ) + _client: InferenceClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> InferenceClient: + return import_optional_dependency("huggingface_hub").InferenceClient( + model=self.model, + token=self.api_token, + ) def try_embed_chunk(self, chunk: str) -> list[float]: response = self.client.feature_extraction(chunk) diff --git a/griptape/drivers/embedding/ollama_embedding_driver.py b/griptape/drivers/embedding/ollama_embedding_driver.py index c5c30d5af..1b32a21f3 100644 --- a/griptape/drivers/embedding/ollama_embedding_driver.py +++ b/griptape/drivers/embedding/ollama_embedding_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseEmbeddingDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from ollama import Client @@ -23,10 +24,11 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver): model: str = field(kw_only=True, metadata={"serializable": True}) host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: Client = field( - default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True), - kw_only=True, - ) + _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("ollama").Client(host=self.host) def try_embed_chunk(self, chunk: str) -> list[float]: return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"]) diff --git a/griptape/drivers/embedding/openai_embedding_driver.py b/griptape/drivers/embedding/openai_embedding_driver.py index 0995fba68..b0b799790 100644 --- a/griptape/drivers/embedding/openai_embedding_driver.py +++ b/griptape/drivers/embedding/openai_embedding_driver.py @@ -7,6 +7,7 @@ from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import OpenAiTokenizer +from griptape.utils.decorators import lazy_property @define @@ -33,16 +34,15 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) tokenizer: OpenAiTokenizer = field( default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True, ) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) def try_embed_chunk(self, chunk: str) -> list[float]: # Address a performance issue in older ada models diff --git a/griptape/drivers/embedding/voyageai_embedding_driver.py b/griptape/drivers/embedding/voyageai_embedding_driver.py index c5e418ed1..bc4e78bf1 100644 --- a/griptape/drivers/embedding/voyageai_embedding_driver.py +++ b/griptape/drivers/embedding/voyageai_embedding_driver.py @@ -1,12 +1,16 @@ from __future__ import annotations -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from attrs import Factory, define, field from griptape.drivers import BaseEmbeddingDriver from griptape.tokenizers import VoyageAiTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property + +if TYPE_CHECKING: + import voyageai @define @@ -25,17 +29,16 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver): model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) - client: Any = field( - default=Factory( - lambda self: import_optional_dependency("voyageai").Client(api_key=self.api_key), - takes_self=True, - ), - ) tokenizer: VoyageAiTokenizer = field( default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True), kw_only=True, ) input_type: str = field(default="document", kw_only=True, metadata={"serializable": True}) + _client: voyageai.Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return import_optional_dependency("voyageai").Client(api_key=self.api_key) def try_embed_chunk(self, chunk: str) -> list[float]: return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0] diff --git a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py index 4c632cb01..9030f5d77 100644 --- a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py +++ b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py @@ -1,25 +1,31 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from attrs import Factory, define, field from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_sqs import SQSClient @define class AmazonSqsEventListenerDriver(BaseEventListenerDriver): queue_url: str = field(kw_only=True) session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True)) + _client: SQSClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> SQSClient: + return self.session.client("sqs") def try_publish_event_payload(self, event_payload: dict) -> None: - self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload)) + self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload)) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: entries = [ @@ -27,4 +33,4 @@ def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> No for event_payload in event_payload_batch ] - self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries) + self.client.send_message_batch(QueueUrl=self.queue_url, Entries=entries) diff --git a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py index 3b014aed4..c3a5a55e7 100644 --- a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py +++ b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py @@ -1,15 +1,17 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from attrs import Factory, define, field from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_iot_data import IoTDataPlaneClient @define @@ -17,10 +19,14 @@ class AwsIotCoreEventListenerDriver(BaseEventListenerDriver): iot_endpoint: str = field(kw_only=True) topic: str = field(kw_only=True) session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - iotdata_client: Any = field(default=Factory(lambda self: self.session.client("iot-data"), takes_self=True)) + _client: IoTDataPlaneClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> IoTDataPlaneClient: + return self.session.client("iot-data") def try_publish_event_payload(self, event_payload: dict) -> None: - self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload)) + self.client.publish(topic=self.topic, payload=json.dumps(event_payload)) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: - self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch)) + self.client.publish(topic=self.topic, payload=json.dumps(event_payload_batch)) diff --git a/griptape/drivers/event_listener/pusher_event_listener_driver.py b/griptape/drivers/event_listener/pusher_event_listener_driver.py index ce9a4fb34..33d160b46 100644 --- a/griptape/drivers/event_listener/pusher_event_listener_driver.py +++ b/griptape/drivers/event_listener/pusher_event_listener_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseEventListenerDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from pusher import Pusher @@ -13,25 +14,24 @@ @define class PusherEventListenerDriver(BaseEventListenerDriver): - app_id: str = field(kw_only=True) - key: str = field(kw_only=True) - secret: str = field(kw_only=True) - cluster: str = field(kw_only=True) - channel: str = field(kw_only=True) - event_name: str = field(kw_only=True) - pusher_client: Pusher = field( - default=Factory( - lambda self: import_optional_dependency("pusher").Pusher( - app_id=self.app_id, - key=self.key, - secret=self.secret, - cluster=self.cluster, - ssl=True, - ), - takes_self=True, - ), - kw_only=True, - ) + app_id: str = field(kw_only=True, metadata={"serializable": True}) + key: str = field(kw_only=True, metadata={"serializable": True}) + secret: str = field(kw_only=True, metadata={"serializable": False}) + cluster: str = field(kw_only=True, metadata={"serializable": True}) + channel: str = field(kw_only=True, metadata={"serializable": True}) + event_name: str = field(kw_only=True, metadata={"serializable": True}) + ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + _client: Pusher = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Pusher: + return import_optional_dependency("pusher").Pusher( + app_id=self.app_id, + key=self.key, + secret=self.secret, + cluster=self.cluster, + ssl=self.ssl, + ) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: data = [ @@ -39,7 +39,7 @@ def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> No for event_payload in event_payload_batch ] - self.pusher_client.trigger_batch(data) + self.client.trigger_batch(data) def try_publish_event_payload(self, event_payload: dict) -> None: - self.pusher_client.trigger(channels=self.channel, event_name=self.event_name, data=event_payload) + self.client.trigger(channels=self.channel, event_name=self.event_name, data=event_payload) diff --git a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py index 20e432c0b..1e841866a 100644 --- a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py +++ b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py @@ -1,15 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from attrs import Attribute, Factory, define, field +from griptape.utils.decorators import lazy_property from griptape.utils.import_utils import import_optional_dependency from .base_file_manager_driver import BaseFileManagerDriver if TYPE_CHECKING: import boto3 + from mypy_boto3_s3 import S3Client @define @@ -21,13 +23,17 @@ class AmazonS3FileManagerDriver(BaseFileManagerDriver): bucket: The name of the S3 bucket. workdir: The absolute working directory (must start with "/"). List, load, and save operations will be performed relative to this directory. - s3_client: The S3 client to use for S3 operations. + client: The S3 client to use for S3 operations. """ session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) bucket: str = field(kw_only=True) workdir: str = field(default="/", kw_only=True) - s3_client: Any = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True) + _client: S3Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> S3Client: + return self.session.client("s3") @workdir.validator # pyright: ignore[reportAttributeAccessIssue] def validate_workdir(self, _: Attribute, workdir: str) -> None: @@ -51,7 +57,7 @@ def try_load_file(self, path: str) -> bytes: raise IsADirectoryError try: - response = self.s3_client.get_object(Bucket=self.bucket, Key=full_key) + response = self.client.get_object(Bucket=self.bucket, Key=full_key) return response["Body"].read() except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] in {"NoSuchKey", "404"}: @@ -62,7 +68,7 @@ def try_save_file(self, path: str, value: bytes) -> None: full_key = self._to_full_key(path) if self._is_a_directory(full_key): raise IsADirectoryError - self.s3_client.put_object(Bucket=self.bucket, Key=full_key, Body=value) + self.client.put_object(Bucket=self.bucket, Key=full_key, Body=value) def _to_full_key(self, path: str) -> str: path = path.lstrip("/") @@ -90,7 +96,7 @@ def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]: if max_items is not None: pagination_config["MaxItems"] = max_items - paginator = self.s3_client.get_paginator("list_objects_v2") + paginator = self.client.get_paginator("list_objects_v2") pages = paginator.paginate( Bucket=self.bucket, Prefix=full_key, @@ -116,7 +122,7 @@ def _is_a_directory(self, full_key: str) -> bool: return True try: - self.s3_client.head_object(Bucket=self.bucket, Key=full_key) + self.client.head_object(Bucket=self.bucket, Key=full_key) except botocore.exceptions.ClientError as e: if e.response["Error"]["Code"] in {"NoSuchKey", "404"}: return len(self._list_files_and_dirs(full_key, max_items=1)) > 0 diff --git a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py index 4db302f6f..3e69036f6 100644 --- a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py +++ b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py @@ -1,16 +1,18 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from attrs import Factory, define, field from griptape.artifacts import ImageArtifact from griptape.drivers import BaseMultiModelImageGenerationDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_bedrock import BedrockClient @define @@ -20,19 +22,21 @@ class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver): Attributes: model: Bedrock model ID. session: boto3 session. - bedrock_client: Bedrock runtime client. + client: Bedrock runtime client. image_width: Width of output images. Defaults to 512 and must be a multiple of 64. image_height: Height of output images. Defaults to 512 and must be a multiple of 64. seed: Optionally provide a consistent seed to generation requests, increasing consistency in output. """ session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client(service_name="bedrock-runtime"), takes_self=True), - ) image_width: int = field(default=512, kw_only=True, metadata={"serializable": True}) image_height: int = field(default=512, kw_only=True, metadata={"serializable": True}) seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) + _client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> BedrockClient: + return self.session.client("bedrock-runtime") def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: request = self.image_generation_model_driver.text_to_image_request_parameters( @@ -127,7 +131,7 @@ def try_image_outpainting( ) def _make_request(self, request: dict) -> bytes: - response = self.bedrock_client.invoke_model( + response = self.client.invoke_model( body=json.dumps(request), modelId=self.model, accept="application/json", diff --git a/griptape/drivers/image_generation/azure_openai_image_generation_driver.py b/griptape/drivers/image_generation/azure_openai_image_generation_driver.py index 85facda4c..2555fcfd0 100644 --- a/griptape/drivers/image_generation/azure_openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/azure_openai_image_generation_driver.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.drivers import OpenAiImageGenerationDriver +from griptape.utils.decorators import lazy_property @define @@ -34,17 +35,16 @@ class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver): metadata={"serializable": False}, ) api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) diff --git a/griptape/drivers/image_generation/openai_image_generation_driver.py b/griptape/drivers/image_generation/openai_image_generation_driver.py index bf77ac300..ec8129e89 100644 --- a/griptape/drivers/image_generation/openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/openai_image_generation_driver.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING, Literal, Optional, Union, cast import openai -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import ImageArtifact from griptape.drivers import BaseImageGenerationDriver +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from openai.types.images_response import ImagesResponse @@ -38,12 +39,6 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) style: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) quality: Union[Literal["standard"], Literal["hd"]] = field( default="standard", @@ -58,6 +53,11 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver): Literal["1792x1024"], ] = field(default="1024x1024", kw_only=True, metadata={"serializable": True}) response_format: Literal["b64_json"] = field(default="b64_json", kw_only=True, metadata={"serializable": True}) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: prompt = ", ".join(prompts) diff --git a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py index 46406d972..9742cb9c7 100644 --- a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py +++ b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py @@ -1,15 +1,17 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from attrs import Factory, define, field from griptape.drivers import BaseMultiModelImageQueryDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_bedrock import BedrockClient from griptape.artifacts import ImageArtifact, TextArtifact @@ -17,15 +19,16 @@ @define class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) + _client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> BedrockClient: + return self.session.client("bedrock-runtime") def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens) - response = self.bedrock_client.invoke_model( + response = self.client.invoke_model( modelId=self.model, contentType="application/json", accept="application/json", diff --git a/griptape/drivers/image_query/anthropic_image_query_driver.py b/griptape/drivers/image_query/anthropic_image_query_driver.py index a50685724..191d95373 100644 --- a/griptape/drivers/image_query/anthropic_image_query_driver.py +++ b/griptape/drivers/image_query/anthropic_image_query_driver.py @@ -1,12 +1,16 @@ from __future__ import annotations -from typing import Any, Optional +from typing import TYPE_CHECKING, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers import BaseImageQueryDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property + +if TYPE_CHECKING: + from anthropic import Anthropic @define @@ -21,13 +25,11 @@ class AnthropicImageQueryDriver(BaseImageQueryDriver): api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) model: str = field(kw_only=True, metadata={"serializable": True}) - client: Any = field( - default=Factory( - lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - ) + _client: Anthropic = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Anthropic: + return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: if self.max_tokens is None: diff --git a/griptape/drivers/image_query/azure_openai_image_query_driver.py b/griptape/drivers/image_query/azure_openai_image_query_driver.py index 04492e471..637fa11cc 100644 --- a/griptape/drivers/image_query/azure_openai_image_query_driver.py +++ b/griptape/drivers/image_query/azure_openai_image_query_driver.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.drivers.image_query.openai_image_query_driver import OpenAiImageQueryDriver +from griptape.utils.decorators import lazy_property @define @@ -34,17 +35,16 @@ class AzureOpenAiImageQueryDriver(OpenAiImageQueryDriver): metadata={"serializable": False}, ) api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) diff --git a/griptape/drivers/image_query/openai_image_query_driver.py b/griptape/drivers/image_query/openai_image_query_driver.py index 6399efa95..f0ef9e148 100644 --- a/griptape/drivers/image_query/openai_image_query_driver.py +++ b/griptape/drivers/image_query/openai_image_query_driver.py @@ -3,7 +3,7 @@ from typing import Literal, Optional import openai -from attrs import Factory, define, field +from attrs import define, field from openai.types.chat import ( ChatCompletionContentPartImageParam, ChatCompletionContentPartParam, @@ -13,6 +13,7 @@ from griptape.artifacts import ImageArtifact, TextArtifact from griptape.drivers.image_query.base_image_query_driver import BaseImageQueryDriver +from griptape.utils.decorators import lazy_property @define @@ -24,12 +25,11 @@ class OpenAiImageQueryDriver(BaseImageQueryDriver): api_key: Optional[str] = field(default=None, kw_only=True) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization) def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: message_parts: list[ChatCompletionContentPartParam] = [ diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index 0842870eb..47ea13e0a 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -7,9 +7,11 @@ from griptape.drivers import BaseConversationMemoryDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_dynamodb.service_resource import Table from griptape.memory.structure import Run @@ -23,11 +25,11 @@ class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver): partition_key_value: str = field(kw_only=True, metadata={"serializable": True}) sort_key: Optional[str] = field(default=None, metadata={"serializable": True}) sort_key_value: Optional[str | int] = field(default=None, metadata={"serializable": True}) + _table: Table = field(default=None, kw_only=True, alias="table", metadata={"serializable": False}) - table: Any = field(init=False) - - def __attrs_post_init__(self) -> None: - self.table = self.session.resource("dynamodb").Table(self.table_name) + @lazy_property() + def table(self) -> Table: + return self.session.resource("dynamodb").Table(self.table_name) def store(self, runs: list[Run], metadata: dict) -> None: self.table.update_item( diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index bc339f618..be34d2a8c 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -31,6 +31,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -44,10 +45,6 @@ @define class AmazonBedrockPromptDriver(BasePromptDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - bedrock_client: Any = field( - default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), - kw_only=True, - ) additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True) tokenizer: BaseTokenizer = field( default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), @@ -55,10 +52,15 @@ class AmazonBedrockPromptDriver(BasePromptDriver): ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True}) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("bedrock-runtime") @observable def try_run(self, prompt_stack: PromptStack) -> Message: - response = self.bedrock_client.converse(**self._base_params(prompt_stack)) + response = self.client.converse(**self._base_params(prompt_stack)) usage = response["usage"] output_message = response["output"]["message"] @@ -71,7 +73,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: @observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: - response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack)) + response = self.client.converse_stream(**self._base_params(prompt_stack)) stream = response.get("stream") if stream is not None: diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index d7a2f5b0b..2dcf55307 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -10,6 +10,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -22,10 +23,6 @@ @define class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - sagemaker_client: Any = field( - default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), - kw_only=True, - ) endpoint: str = field(kw_only=True, metadata={"serializable": True}) custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True}) inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) @@ -38,6 +35,11 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): ), kw_only=True, ) + _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Any: + return self.session.client("sagemaker-runtime") @stream.validator # pyright: ignore[reportAttributeAccessIssue] def validate_stream(self, _: Attribute, stream: bool) -> None: # noqa: FBT001 @@ -51,7 +53,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: "parameters": {**self._base_params(prompt_stack)}, } - response = self.sagemaker_client.invoke_endpoint( + response = self.client.invoke_endpoint( EndpointName=self.endpoint, ContentType="application/json", Body=json.dumps(payload), diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index ae50bc59e..8c944b2cc 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -32,6 +32,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -54,13 +55,6 @@ class AnthropicPromptDriver(BasePromptDriver): api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) model: str = field(kw_only=True, metadata={"serializable": True}) - client: Client = field( - default=Factory( - lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - ) tokenizer: BaseTokenizer = field( default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True, @@ -70,6 +64,11 @@ class AnthropicPromptDriver(BasePromptDriver): tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) + _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py index b08b51b69..5bb7e0760 100644 --- a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.drivers import OpenAiChatPromptDriver +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from griptape.common import PromptStack @@ -37,20 +38,19 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver): metadata={"serializable": False}, ) api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) def _base_params(self, prompt_stack: PromptStack) -> dict: params = super()._base_params(prompt_stack) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index ff1a8b482..b31c78ea3 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -23,6 +23,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, CohereTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -45,14 +46,16 @@ class CoherePromptDriver(BasePromptDriver): api_key: str = field(metadata={"serializable": False}) model: str = field(metadata={"serializable": True}) - client: Client = field( - default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True), - ) + force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), ) - force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("cohere").Client(self.api_key) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 6b18f6041..4afdad5c6 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -26,6 +26,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, GoogleTokenizer from griptape.utils import import_optional_dependency, remove_key_in_dict_recursively +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -44,17 +45,13 @@ class GooglePromptDriver(BasePromptDriver): Attributes: api_key: Google API key. model: Google model name. - model_client: Custom `GenerativeModel` client. + client: Custom `GenerativeModel` client. top_p: Optional value for top_p. top_k: Optional value for top_k. """ api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) model: str = field(kw_only=True, metadata={"serializable": True}) - model_client: GenerativeModel = field( - default=Factory(lambda self: self._default_model_client(), takes_self=True), - kw_only=True, - ) tokenizer: BaseTokenizer = field( default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True, @@ -63,11 +60,19 @@ class GooglePromptDriver(BasePromptDriver): top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) + _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> GenerativeModel: + genai = import_optional_dependency("google.generativeai") + genai.configure(api_key=self.api_key) + + return genai.GenerativeModel(self.model) @observable def try_run(self, prompt_stack: PromptStack) -> Message: messages = self.__to_google_messages(prompt_stack) - response: GenerateContentResponse = self.model_client.generate_content( + response: GenerateContentResponse = self.client.generate_content( messages, **self._base_params(prompt_stack), ) @@ -86,7 +91,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message: @observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: messages = self.__to_google_messages(prompt_stack) - response: GenerateContentResponse = self.model_client.generate_content( + response: GenerateContentResponse = self.client.generate_content( messages, **self._base_params(prompt_stack), stream=True, @@ -119,7 +124,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = prompt_stack.system_messages if system_messages: - self.model_client._system_instruction = types.ContentDict( + self.client._system_instruction = types.ContentDict( role="system", parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages], ) @@ -146,12 +151,6 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: ), } - def _default_model_client(self) -> GenerativeModel: - genai = import_optional_dependency("google.generativeai") - genai.configure(api_key=self.api_key) - - return genai.GenerativeModel(self.model) - def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: types = import_optional_dependency("google.generativeai.types") diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 657b5747c..68267f755 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -8,6 +8,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -32,16 +33,6 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) - client: InferenceClient = field( - default=Factory( - lambda self: import_optional_dependency("huggingface_hub").InferenceClient( - model=self.model, - token=self.api_token, - ), - takes_self=True, - ), - kw_only=True, - ) tokenizer: HuggingFaceTokenizer = field( default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), @@ -49,6 +40,14 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): ), kw_only=True, ) + _client: InferenceClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> InferenceClient: + return import_optional_dependency("huggingface_hub").InferenceClient( + model=self.model, + token=self.api_token, + ) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 128167f52..1978b339a 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -9,6 +9,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -35,23 +36,24 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ), kw_only=True, ) - pipe: TextGenerationPipeline = field( - default=Factory( - lambda self: import_optional_dependency("transformers").pipeline( - "text-generation", - model=self.model, - max_new_tokens=self.max_tokens, - tokenizer=self.tokenizer.tokenizer, - ), - takes_self=True, - ), + _pipeline: TextGenerationPipeline = field( + default=None, kw_only=True, alias="pipeline", metadata={"serializable": False} ) + @lazy_property() + def pipeline(self) -> TextGenerationPipeline: + return import_optional_dependency("transformers").pipeline( + task="text-generation", + model=self.model, + max_new_tokens=self.max_tokens, + tokenizer=self.tokenizer.tokenizer, + ) + @observable def try_run(self, prompt_stack: PromptStack) -> Message: messages = self._prompt_stack_to_messages(prompt_stack) - result = self.pipe( + result = self.pipeline( messages, max_new_tokens=self.max_tokens, temperature=self.temperature, diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 70d4ce89a..5f9e32e2f 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -22,6 +22,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import SimpleTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from ollama import Client @@ -40,10 +41,6 @@ class OllamaPromptDriver(BasePromptDriver): model: str = field(kw_only=True, metadata={"serializable": True}) host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: Client = field( - default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True), - kw_only=True, - ) tokenizer: BaseTokenizer = field( default=Factory( lambda self: SimpleTokenizer( @@ -67,6 +64,11 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Client: + return import_optional_dependency("ollama").Client(host=self.host) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 987bdc2ad..bab20d3f0 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -25,6 +25,7 @@ ) from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from collections.abc import Iterator @@ -55,12 +56,6 @@ class OpenAiChatPromptDriver(BasePromptDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) model: str = field(kw_only=True, metadata={"serializable": True}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), @@ -88,6 +83,15 @@ class OpenAiChatPromptDriver(BasePromptDriver): ), kw_only=True, ) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI( + base_url=self.base_url, + api_key=self.api_key, + organization=self.organization, + ) @observable def try_run(self, prompt_stack: PromptStack) -> Message: diff --git a/griptape/drivers/sql/amazon_redshift_sql_driver.py b/griptape/drivers/sql/amazon_redshift_sql_driver.py index 837405e83..8e5d912c8 100644 --- a/griptape/drivers/sql/amazon_redshift_sql_driver.py +++ b/griptape/drivers/sql/amazon_redshift_sql_driver.py @@ -3,12 +3,14 @@ import time from typing import TYPE_CHECKING, Any, Optional -from attrs import Attribute, Factory, define, field +from attrs import Attribute, define, field from griptape.drivers import BaseSqlDriver +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import boto3 + from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient @define @@ -20,11 +22,14 @@ class AmazonRedshiftSqlDriver(BaseSqlDriver): db_user: Optional[str] = field(default=None, kw_only=True) database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True) wait_for_query_completion_sec: float = field(default=0.3, kw_only=True) - client: Any = field( - default=Factory(lambda self: self.session.client("redshift-data"), takes_self=True), - kw_only=True, + _client: RedshiftDataAPIServiceClient = field( + default=None, kw_only=True, alias="client", metadata={"serializable": False} ) + @lazy_property() + def client(self) -> RedshiftDataAPIServiceClient: + return self.session.client("redshift-data") + @workgroup_name.validator # pyright: ignore[reportAttributeAccessIssue] def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None: if not self.cluster_identifier and not self.workgroup_name: diff --git a/griptape/drivers/sql/snowflake_sql_driver.py b/griptape/drivers/sql/snowflake_sql_driver.py index 656bc4b99..d1b4310b5 100644 --- a/griptape/drivers/sql/snowflake_sql_driver.py +++ b/griptape/drivers/sql/snowflake_sql_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING, Any, Callable, Optional -from attrs import Attribute, Factory, define, field +from attrs import Attribute, define, field from griptape.drivers import BaseSqlDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from snowflake.connector import SnowflakeConnection @@ -15,18 +16,7 @@ @define class SnowflakeSqlDriver(BaseSqlDriver): connection_func: Callable[[], SnowflakeConnection] = field(kw_only=True) - engine: Engine = field( - default=Factory( - # Creator bypasses the URL param - # https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.creator - lambda self: import_optional_dependency("sqlalchemy").create_engine( - "snowflake://not@used/db", - creator=self.connection_func, - ), - takes_self=True, - ), - kw_only=True, - ) + _engine: Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False}) @connection_func.validator # pyright: ignore[reportFunctionMemberAccess] def validate_connection_func(self, _: Attribute, connection_func: Callable[[], SnowflakeConnection]) -> None: @@ -38,10 +28,12 @@ def validate_connection_func(self, _: Attribute, connection_func: Callable[[], S if not snowflake_connection.schema or not snowflake_connection.database: raise ValueError("Provide a schema and database for the Snowflake connection") - @engine.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_engine_url(self, _: Attribute, engine: Engine) -> None: - if not engine.url.render_as_string().startswith("snowflake://"): - raise ValueError("Provide a Snowflake connection") + @lazy_property() + def engine(self) -> Engine: + return import_optional_dependency("sqlalchemy").create_engine( + "snowflake://not@used/db", + creator=self.connection_func, + ) def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]: rows = self.execute_query_raw(query) diff --git a/griptape/drivers/sql/sql_driver.py b/griptape/drivers/sql/sql_driver.py index d2293f94d..cb7a67341 100644 --- a/griptape/drivers/sql/sql_driver.py +++ b/griptape/drivers/sql/sql_driver.py @@ -6,6 +6,7 @@ from griptape.drivers import BaseSqlDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from sqlalchemy.engine import Engine @@ -15,12 +16,11 @@ class SqlDriver(BaseSqlDriver): engine_url: str = field(kw_only=True) create_engine_params: dict = field(factory=dict, kw_only=True) - engine: Engine = field(init=False) + _engine: Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False}) - def __attrs_post_init__(self) -> None: - sqlalchemy = import_optional_dependency("sqlalchemy") - - self.engine = sqlalchemy.create_engine(self.engine_url, **self.create_engine_params) + @lazy_property() + def engine(self) -> Engine: + return import_optional_dependency("sqlalchemy").create_engine(self.engine_url, **self.create_engine_params) def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]: rows = self.execute_query_raw(query) diff --git a/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py b/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py index 562a1d637..f64ab0e2d 100644 --- a/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py @@ -6,6 +6,7 @@ from attrs import Factory, define, field from griptape.drivers import OpenAiTextToSpeechDriver +from griptape.utils.decorators import lazy_property @define @@ -35,17 +36,16 @@ class AzureOpenAiTextToSpeechDriver(OpenAiTextToSpeechDriver): metadata={"serializable": False}, ) api_version: str = field(default="2024-07-01-preview", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ), - ) + _client: openai.AzureOpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.AzureOpenAI: + return openai.AzureOpenAI( + organization=self.organization, + api_key=self.api_key, + api_version=self.api_version, + azure_endpoint=self.azure_endpoint, + azure_deployment=self.azure_deployment, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) diff --git a/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py index f4be58162..ef6352cea 100644 --- a/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py @@ -1,27 +1,28 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact from griptape.drivers import BaseTextToSpeechDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property + +if TYPE_CHECKING: + from elevenlabs.client import ElevenLabs @define class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver): api_key: str = field(kw_only=True, metadata={"serializable": True}) - client: Any = field( - default=Factory( - lambda self: import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - metadata={"serializable": True}, - ) voice: str = field(kw_only=True, metadata={"serializable": True}) output_format: str = field(default="mp3_44100_128", kw_only=True, metadata={"serializable": True}) + _client: ElevenLabs = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> ElevenLabs: + return import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key) def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: audio = self.client.generate( diff --git a/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py b/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py index 543ef1ec7..558e2f875 100644 --- a/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py +++ b/griptape/drivers/text_to_speech/openai_text_to_speech_driver.py @@ -3,10 +3,11 @@ from typing import Literal, Optional import openai -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact from griptape.drivers import BaseTextToSpeechDriver +from griptape.utils.decorators import lazy_property @define @@ -23,12 +24,15 @@ class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver): base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ), - ) + _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> openai.OpenAI: + return openai.OpenAI( + api_key=self.api_key, + base_url=self.base_url, + organization=self.organization, + ) def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: response = self.client.audio.speech.create( diff --git a/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py b/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py index b1d881958..465dfa476 100644 --- a/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: from boto3 import Session - from opensearchpy import OpenSearch @define @@ -36,19 +35,6 @@ class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver): ), ) - client: OpenSearch = field( - default=Factory( - lambda self: import_optional_dependency("opensearchpy").OpenSearch( - hosts=[{"host": self.host, "port": self.port}], - http_auth=self.http_auth, - use_ssl=self.use_ssl, - verify_certs=self.verify_certs, - connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection, - ), - takes_self=True, - ), - ) - def upsert_vector( self, vector: list[float], diff --git a/griptape/drivers/vector/astradb_vector_store_driver.py b/griptape/drivers/vector/astradb_vector_store_driver.py index 029fa382d..1e8398809 100644 --- a/griptape/drivers/vector/astradb_vector_store_driver.py +++ b/griptape/drivers/vector/astradb_vector_store_driver.py @@ -6,10 +6,11 @@ from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: - from astrapy import Collection - from astrapy.authentication import TokenProvider + import astrapy + import astrapy.authentication @define @@ -26,33 +27,37 @@ class AstraDbVectorStoreDriver(BaseVectorStoreDriver): It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values. astra_db_namespace: optional specification of the namespace (in the Astra database) for the data. *Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store. + caller_name: the name of the caller for the Astra DB client. Defaults to "griptape". + client: an instance of `astrapy.DataAPIClient` for the Astra DB. + collection: an instance of `astrapy.Collection` for the Astra DB. """ api_endpoint: str = field(kw_only=True, metadata={"serializable": True}) - token: Optional[str | TokenProvider] = field(kw_only=True, default=None, metadata={"serializable": False}) + token: Optional[str | astrapy.authentication.TokenProvider] = field( + kw_only=True, default=None, metadata={"serializable": False} + ) collection_name: str = field(kw_only=True, metadata={"serializable": True}) environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - - collection: Collection = field(init=False) - - def __attrs_post_init__(self) -> None: - astrapy = import_optional_dependency("astrapy") - self.collection = ( - astrapy.DataAPIClient( - caller_name="griptape", - environment=self.environment, - ) - .get_database( - self.api_endpoint, - token=self.token, - namespace=self.astra_db_namespace, - ) - .get_collection( - name=self.collection_name, - ) + caller_name: str = field(default="griptape", kw_only=True, metadata={"serializable": False}) + _client: astrapy.DataAPIClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + _collection: astrapy.Collection = field( + default=None, kw_only=True, alias="collection", metadata={"serializable": False} + ) + + @lazy_property() + def client(self) -> astrapy.DataAPIClient: + return import_optional_dependency("astrapy").DataAPIClient( + caller_name=self.caller_name, + environment=self.environment, ) + @lazy_property() + def collection(self) -> astrapy.Collection: + return self.client.get_database( + self.api_endpoint, token=self.token, namespace=self.astra_db_namespace + ).get_collection(self.collection_name) + def delete_vector(self, vector_id: str) -> None: """Delete a vector from Astra DB store. diff --git a/griptape/drivers/vector/marqo_vector_store_driver.py b/griptape/drivers/vector/marqo_vector_store_driver.py index caab118b8..55c3692a1 100644 --- a/griptape/drivers/vector/marqo_vector_store_driver.py +++ b/griptape/drivers/vector/marqo_vector_store_driver.py @@ -2,11 +2,12 @@ from typing import TYPE_CHECKING, Any, NoReturn, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape import utils from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import marqo @@ -21,20 +22,18 @@ class MarqoVectorStoreDriver(BaseVectorStoreDriver): Attributes: api_key: The API key for the Marqo API. url: The URL to the Marqo API. - mq: An optional Marqo client. Defaults to a new client with the given URL and API key. + client: An optional Marqo client. Defaults to a new client with the given URL and API key. index: The name of the index to use. """ api_key: str = field(kw_only=True, metadata={"serializable": True}) url: str = field(kw_only=True, metadata={"serializable": True}) - mq: Optional[marqo.Client] = field( - default=Factory( - lambda self: import_optional_dependency("marqo").Client(self.url, api_key=self.api_key), - takes_self=True, - ), - kw_only=True, - ) index: str = field(kw_only=True, metadata={"serializable": True}) + _client: marqo.Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> marqo.Client: + return import_optional_dependency("marqo").Client(self.url, api_key=self.api_key) def upsert_text( self, @@ -65,7 +64,7 @@ def upsert_text( if namespace: doc["namespace"] = namespace - response = self.mq.index(self.index).add_documents([doc], tensor_fields=["Description"]) + response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description"]) if isinstance(response, dict) and "items" in response and response["items"]: return response["items"][0]["_id"] else: @@ -102,7 +101,7 @@ def upsert_text_artifact( "namespace": namespace, } - response = self.mq.index(self.index).add_documents([doc], tensor_fields=["Description", "artifact"]) + response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description", "artifact"]) if isinstance(response, dict) and "items" in response and response["items"]: return response["items"][0]["_id"] else: @@ -118,7 +117,7 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti Returns: The loaded Entry if found, otherwise None. """ - result = self.mq.index(self.index).get_document(document_id=vector_id, expose_facets=True) + result = self.client.index(self.index).get_document(document_id=vector_id, expose_facets=True) if result and "_tensor_facets" in result and len(result["_tensor_facets"]) > 0: return BaseVectorStoreDriver.Entry( @@ -141,15 +140,15 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto filter_string = f"namespace:{namespace}" if namespace else None if filter_string is not None: - results = self.mq.index(self.index).search("", limit=10000, filter_string=filter_string) + results = self.client.index(self.index).search("", limit=10000, filter_string=filter_string) else: - results = self.mq.index(self.index).search("", limit=10000) + results = self.client.index(self.index).search("", limit=10000) # get all _id's from search results ids = [r["_id"] for r in results["hits"]] # get documents corresponding to the ids - documents = self.mq.index(self.index).get_documents(document_ids=ids, expose_facets=True) + documents = self.client.index(self.index).get_documents(document_ids=ids, expose_facets=True) # for each document, if it's found, create an Entry object entries = [] @@ -195,11 +194,12 @@ def query( "filter_string": f"namespace:{namespace}" if namespace else None, } | kwargs - results = self.mq.index(self.index).search(query, **params) + results = self.client.index(self.index).search(query, **params) if include_vectors: results["hits"] = [ - {**r, **self.mq.index(self.index).get_document(r["_id"], expose_facets=True)} for r in results["hits"] + {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)} + for r in results["hits"] ] return [ @@ -218,7 +218,7 @@ def delete_index(self, name: str) -> dict[str, Any]: Args: name: The name of the index to delete. """ - return self.mq.delete_index(name) + return self.client.delete_index(name) def get_indexes(self) -> list[str]: """Get a list of all indexes in the Marqo client. @@ -226,7 +226,7 @@ def get_indexes(self) -> list[str]: Returns: The list of all indexes. """ - return [index["index"] for index in self.mq.get_indexes()["results"]] + return [index["index"] for index in self.client.get_indexes()["results"]] def upsert_vector( self, diff --git a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py index bc3f1e22f..a6f32620a 100644 --- a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py +++ b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from pymongo import MongoClient @@ -37,12 +38,11 @@ class MongoDbAtlasVectorStoreDriver(BaseVectorStoreDriver): kw_only=True, metadata={"serializable": True}, ) # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#fields - client: MongoClient = field( - default=Factory( - lambda self: import_optional_dependency("pymongo").MongoClient(self.connection_string), - takes_self=True, - ), - ) + _client: MongoClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> MongoClient: + return import_optional_dependency("pymongo").MongoClient(self.connection_string) def get_collection(self) -> Collection: """Returns the MongoDB Collection instance for the specified database and collection name.""" diff --git a/griptape/drivers/vector/opensearch_vector_store_driver.py b/griptape/drivers/vector/opensearch_vector_store_driver.py index cf944116a..5f247f6db 100644 --- a/griptape/drivers/vector/opensearch_vector_store_driver.py +++ b/griptape/drivers/vector/opensearch_vector_store_driver.py @@ -3,11 +3,12 @@ import logging from typing import TYPE_CHECKING, NoReturn, Optional -from attrs import Factory, define, field +from attrs import define, field from griptape import utils from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from opensearchpy import OpenSearch @@ -32,19 +33,19 @@ class OpenSearchVectorStoreDriver(BaseVectorStoreDriver): use_ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True}) verify_certs: bool = field(default=True, kw_only=True, metadata={"serializable": True}) index_name: str = field(kw_only=True, metadata={"serializable": True}) - - client: OpenSearch = field( - default=Factory( - lambda self: import_optional_dependency("opensearchpy").OpenSearch( - hosts=[{"host": self.host, "port": self.port}], - http_auth=self.http_auth, - use_ssl=self.use_ssl, - verify_certs=self.verify_certs, - connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection, - ), - takes_self=True, - ), - ) + _client: OpenSearch = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> OpenSearch: + opensearchpy = import_optional_dependency("opensearchpy") + + return opensearchpy.OpenSearch( + hosts=[{"host": self.host, "port": self.port}], + http_auth=self.http_auth, + use_ssl=self.use_ssl, + verify_certs=self.verify_certs, + connection_class=opensearchpy.RequestsHttpConnection, + ) def upsert_vector( self, diff --git a/griptape/drivers/vector/pgvector_vector_store_driver.py b/griptape/drivers/vector/pgvector_vector_store_driver.py index 30f437c7e..c1a6bef06 100644 --- a/griptape/drivers/vector/pgvector_vector_store_driver.py +++ b/griptape/drivers/vector/pgvector_vector_store_driver.py @@ -9,9 +9,10 @@ from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: - from sqlalchemy.engine import Engine + import sqlalchemy @define @@ -27,14 +28,14 @@ class PgVectorVectorStoreDriver(BaseVectorStoreDriver): connection_string: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) - engine: Optional[Engine] = field(default=None, kw_only=True) table_name: str = field(kw_only=True, metadata={"serializable": True}) _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True)) + _engine: sqlalchemy.Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False}) @connection_string.validator # pyright: ignore[reportAttributeAccessIssue] def validate_connection_string(self, _: Attribute, connection_string: Optional[str]) -> None: # If an engine is provided, the connection string is not used. - if self.engine is not None: + if self._engine is not None: return # If an engine is not provided, a connection string is required. @@ -44,22 +45,11 @@ def validate_connection_string(self, _: Attribute, connection_string: Optional[s if not connection_string.startswith("postgresql://"): raise ValueError("The connection string must describe a Postgres database connection") - @engine.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_engine(self, _: Attribute, engine: Optional[Engine]) -> None: - # If a connection string is provided, an engine does not need to be provided. - if self.connection_string is not None: - return - - # If a connection string is not provided, an engine is required. - if engine is None: - raise ValueError("An engine or connection string is required") - - def __attrs_post_init__(self) -> None: - if self.engine is None: - if self.connection_string is None: - raise ValueError("An engine or connection string is required") - sqlalchemy = import_optional_dependency("sqlalchemy") - self.engine = sqlalchemy.create_engine(self.connection_string, **self.create_engine_params) + @lazy_property() + def engine(self) -> sqlalchemy.Engine: + return import_optional_dependency("sqlalchemy").create_engine( + self.connection_string, **self.create_engine_params + ) def setup( self, diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index a3a132ab3..500b090f5 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -6,6 +6,7 @@ from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency, str_to_hash +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: import pinecone @@ -17,16 +18,20 @@ class PineconeVectorStoreDriver(BaseVectorStoreDriver): index_name: str = field(kw_only=True, metadata={"serializable": True}) environment: str = field(kw_only=True, metadata={"serializable": True}) project_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - index: pinecone.Index = field(init=False) + _client: pinecone.Pinecone = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + _index: pinecone.Index = field(default=None, kw_only=True, alias="index", metadata={"serializable": False}) - def __attrs_post_init__(self) -> None: - pinecone = import_optional_dependency("pinecone").Pinecone( + @lazy_property() + def client(self) -> pinecone.Pinecone: + return import_optional_dependency("pinecone").Pinecone( api_key=self.api_key, environment=self.environment, project_name=self.project_name, ) - self.index = pinecone.Index(self.index_name) + @lazy_property() + def index(self) -> pinecone.Index: + return self.client.get_index(self.index_name) def upsert_vector( self, diff --git a/griptape/drivers/vector/qdrant_vector_store_driver.py b/griptape/drivers/vector/qdrant_vector_store_driver.py index 154e54af7..79cf64f37 100644 --- a/griptape/drivers/vector/qdrant_vector_store_driver.py +++ b/griptape/drivers/vector/qdrant_vector_store_driver.py @@ -2,12 +2,17 @@ import logging import uuid -from typing import Optional +from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property + +if TYPE_CHECKING: + from qdrant_client import QdrantClient + DEFAULT_DISTANCE = "Cosine" CONTENT_PAYLOAD_KEY = "data" @@ -56,9 +61,11 @@ class QdrantVectorStoreDriver(BaseVectorStoreDriver): collection_name: str = field(kw_only=True, metadata={"serializable": True}) vector_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) content_payload_key: str = field(default=CONTENT_PAYLOAD_KEY, kw_only=True, metadata={"serializable": True}) + _client: QdrantClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - def __attrs_post_init__(self) -> None: - self.client = import_optional_dependency("qdrant_client").QdrantClient( + @lazy_property() + def client(self) -> QdrantClient: + return import_optional_dependency("qdrant_client").QdrantClient( location=self.location, url=self.url, host=self.host, diff --git a/griptape/drivers/vector/redis_vector_store_driver.py b/griptape/drivers/vector/redis_vector_store_driver.py index 0abf2c985..d220878f3 100644 --- a/griptape/drivers/vector/redis_vector_store_driver.py +++ b/griptape/drivers/vector/redis_vector_store_driver.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING, NoReturn, Optional import numpy as np -from attrs import Factory, define, field +from attrs import define, field from griptape.drivers import BaseVectorStoreDriver from griptape.utils import import_optional_dependency, str_to_hash +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from redis import Redis @@ -33,19 +34,17 @@ class RedisVectorStoreDriver(BaseVectorStoreDriver): db: int = field(kw_only=True, default=0, metadata={"serializable": True}) password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) index: str = field(kw_only=True, metadata={"serializable": True}) - - client: Redis = field( - default=Factory( - lambda self: import_optional_dependency("redis").Redis( - host=self.host, - port=self.port, - db=self.db, - password=self.password, - decode_responses=False, - ), - takes_self=True, - ), - ) + _client: Redis = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> Redis: + return import_optional_dependency("redis").Redis( + host=self.host, + port=self.port, + db=self.db, + password=self.password, + decode_responses=False, + ) def upsert_vector( self, diff --git a/griptape/drivers/web_search/duck_duck_go_web_search_driver.py b/griptape/drivers/web_search/duck_duck_go_web_search_driver.py index b67e81f35..96891c2d4 100644 --- a/griptape/drivers/web_search/duck_duck_go_web_search_driver.py +++ b/griptape/drivers/web_search/duck_duck_go_web_search_driver.py @@ -3,11 +3,12 @@ import json from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import ListArtifact, TextArtifact from griptape.drivers import BaseWebSearchDriver from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from duckduckgo_search import DDGS @@ -15,7 +16,11 @@ @define class DuckDuckGoWebSearchDriver(BaseWebSearchDriver): - client: DDGS = field(default=Factory(lambda: import_optional_dependency("duckduckgo_search").DDGS()), kw_only=True) + _client: DDGS = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> DDGS: + return import_optional_dependency("duckduckgo_search").DDGS() def search(self, query: str, **kwargs) -> ListArtifact: try: diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index dde3ae49a..4892cbf9b 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -165,6 +165,13 @@ def _resolve_types(cls, attrs_cls: type) -> None: if is_dependency_installed("google.generativeai") else Any, "boto3": import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any, + "Anthropic": import_optional_dependency("anthropic").Anthropic + if is_dependency_installed("anthropic") + else Any, + "BedrockClient": import_optional_dependency("mypy_boto3_bedrock").BedrockClient + if is_dependency_installed("mypy_boto3_bedrock") + else Any, + "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any, }, ) diff --git a/griptape/tokenizers/google_tokenizer.py b/griptape/tokenizers/google_tokenizer.py index 87020bd96..144c09d75 100644 --- a/griptape/tokenizers/google_tokenizer.py +++ b/griptape/tokenizers/google_tokenizer.py @@ -2,10 +2,11 @@ from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from griptape.tokenizers import BaseTokenizer from griptape.utils import import_optional_dependency +from griptape.utils.decorators import lazy_property if TYPE_CHECKING: from google.generativeai import GenerativeModel @@ -17,16 +18,14 @@ class GoogleTokenizer(BaseTokenizer): MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"gemini": 2048} api_key: str = field(kw_only=True, metadata={"serializable": True}) - model_client: GenerativeModel = field( - default=Factory(lambda self: self._default_model_client(), takes_self=True), - kw_only=True, - ) + _client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) - def count_tokens(self, text: str) -> int: - return self.model_client.count_tokens(text).total_tokens - - def _default_model_client(self) -> GenerativeModel: + @lazy_property() + def client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") genai.configure(api_key=self.api_key) return genai.GenerativeModel(self.model) + + def count_tokens(self, text: str) -> int: + return self.client.count_tokens(text).total_tokens diff --git a/griptape/tools/aws_iam/tool.py b/griptape/tools/aws_iam/tool.py index 8d22dd3c9..6c1bed054 100644 --- a/griptape/tools/aws_iam/tool.py +++ b/griptape/tools/aws_iam/tool.py @@ -2,20 +2,24 @@ from typing import TYPE_CHECKING -from attrs import Factory, define, field +from attrs import define, field from schema import Literal, Schema from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact from griptape.tools import BaseAwsTool -from griptape.utils.decorators import activity +from griptape.utils.decorators import activity, lazy_property if TYPE_CHECKING: - from mypy_boto3_iam import Client + from mypy_boto3_iam import IAMClient @define class AwsIamTool(BaseAwsTool): - iam_client: Client = field(default=Factory(lambda self: self.session.client("iam"), takes_self=True), kw_only=True) + _client: IAMClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> IAMClient: + return self.session.client("iam") @activity( config={ @@ -33,7 +37,7 @@ class AwsIamTool(BaseAwsTool): ) def get_user_policy(self, params: dict) -> TextArtifact | ErrorArtifact: try: - policy = self.iam_client.get_user_policy( + policy = self.client.get_user_policy( UserName=params["values"]["user_name"], PolicyName=params["values"]["policy_name"], ) @@ -44,7 +48,7 @@ def get_user_policy(self, params: dict) -> TextArtifact | ErrorArtifact: @activity(config={"description": "Can be used to list AWS MFA Devices"}) def list_mfa_devices(self, _: dict) -> ListArtifact | ErrorArtifact: try: - devices = self.iam_client.list_mfa_devices() + devices = self.client.list_mfa_devices() return ListArtifact([TextArtifact(str(d)) for d in devices["MFADevices"]]) except Exception as e: return ErrorArtifact(f"error listing mfa devices: {e}") @@ -59,10 +63,10 @@ def list_mfa_devices(self, _: dict) -> ListArtifact | ErrorArtifact: ) def list_user_policies(self, params: dict) -> ListArtifact | ErrorArtifact: try: - policies = self.iam_client.list_user_policies(UserName=params["values"]["user_name"]) + policies = self.client.list_user_policies(UserName=params["values"]["user_name"]) policy_names = policies["PolicyNames"] - attached_policies = self.iam_client.list_attached_user_policies(UserName=params["values"]["user_name"]) + attached_policies = self.client.list_attached_user_policies(UserName=params["values"]["user_name"]) attached_policy_names = [ p["PolicyName"] for p in attached_policies["AttachedPolicies"] if "PolicyName" in p ] @@ -74,7 +78,7 @@ def list_user_policies(self, params: dict) -> ListArtifact | ErrorArtifact: @activity(config={"description": "Can be used to list AWS IAM users."}) def list_users(self, _: dict) -> ListArtifact | ErrorArtifact: try: - users = self.iam_client.list_users() + users = self.client.list_users() return ListArtifact([TextArtifact(str(u)) for u in users["Users"]]) except Exception as e: return ErrorArtifact(f"error listing s3 users: {e}") diff --git a/griptape/tools/aws_s3/tool.py b/griptape/tools/aws_s3/tool.py index 24d091d71..b352da2d5 100644 --- a/griptape/tools/aws_s3/tool.py +++ b/griptape/tools/aws_s3/tool.py @@ -3,20 +3,24 @@ import io from typing import TYPE_CHECKING, Any -from attrs import Factory, define, field +from attrs import define, field from schema import Literal, Schema from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.tools import BaseAwsTool -from griptape.utils.decorators import activity +from griptape.utils.decorators import activity, lazy_property if TYPE_CHECKING: - from mypy_boto3_s3 import Client + from mypy_boto3_s3 import S3Client @define class AwsS3Tool(BaseAwsTool): - s3_client: Client = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True) + _client: S3Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + + @lazy_property() + def client(self) -> S3Client: + return self.session.client("s3") @activity( config={ @@ -33,7 +37,7 @@ class AwsS3Tool(BaseAwsTool): ) def get_bucket_acl(self, params: dict) -> TextArtifact | ErrorArtifact: try: - acl = self.s3_client.get_bucket_acl(Bucket=params["values"]["bucket_name"]) + acl = self.client.get_bucket_acl(Bucket=params["values"]["bucket_name"]) return TextArtifact(acl) except Exception as e: return ErrorArtifact(f"error getting bucket acl: {e}") @@ -48,7 +52,7 @@ def get_bucket_acl(self, params: dict) -> TextArtifact | ErrorArtifact: ) def get_bucket_policy(self, params: dict) -> TextArtifact | ErrorArtifact: try: - policy = self.s3_client.get_bucket_policy(Bucket=params["values"]["bucket_name"]) + policy = self.client.get_bucket_policy(Bucket=params["values"]["bucket_name"]) return TextArtifact(policy) except Exception as e: return ErrorArtifact(f"error getting bucket policy: {e}") @@ -66,7 +70,7 @@ def get_bucket_policy(self, params: dict) -> TextArtifact | ErrorArtifact: ) def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact: try: - acl = self.s3_client.get_object_acl( + acl = self.client.get_object_acl( Bucket=params["values"]["bucket_name"], Key=params["values"]["object_key"], ) @@ -77,7 +81,7 @@ def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact: @activity(config={"description": "Can be used to list all AWS S3 buckets."}) def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact: try: - buckets = self.s3_client.list_buckets() + buckets = self.client.list_buckets() return ListArtifact([TextArtifact(str(b)) for b in buckets["Buckets"]]) except Exception as e: @@ -91,7 +95,7 @@ def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact: ) def list_objects(self, params: dict) -> ListArtifact | ErrorArtifact: try: - objects = self.s3_client.list_objects_v2(Bucket=params["values"]["bucket_name"]) + objects = self.client.list_objects_v2(Bucket=params["values"]["bucket_name"]) if "Contents" not in objects: return ErrorArtifact("no objects found in the bucket") @@ -192,7 +196,7 @@ def download_objects(self, params: dict) -> ListArtifact | ErrorArtifact: artifacts = [] for object_info in objects: try: - obj = self.s3_client.get_object(Bucket=object_info["bucket_name"], Key=object_info["object_key"]) + obj = self.client.get_object(Bucket=object_info["bucket_name"], Key=object_info["object_key"]) content = obj["Body"].read() artifacts.append(BlobArtifact(content, name=object_info["object_key"])) @@ -203,9 +207,9 @@ def download_objects(self, params: dict) -> ListArtifact | ErrorArtifact: return ListArtifact(artifacts) def _upload_object(self, bucket_name: str, object_name: str, value: Any) -> None: - self.s3_client.create_bucket(Bucket=bucket_name) + self.client.create_bucket(Bucket=bucket_name) - self.s3_client.upload_fileobj( + self.client.upload_fileobj( Fileobj=io.BytesIO(value.encode() if isinstance(value, str) else value), Bucket=bucket_name, Key=object_name, diff --git a/poetry.lock b/poetry.lock index 68d13fd05..d04582db9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -380,10 +380,14 @@ files = [ [package.dependencies] botocore-stubs = "*" mypy-boto3-bedrock = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"bedrock\""} +mypy-boto3-dynamodb = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"dynamodb\""} mypy-boto3-iam = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"iam\""} +mypy-boto3-iot-data = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"iot-data\""} mypy-boto3-opensearch = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"opensearch\""} +mypy-boto3-redshift-data = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"redshift-data\""} mypy-boto3-s3 = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"s3\""} mypy-boto3-sagemaker = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"sagemaker\""} +mypy-boto3-sqs = {version = ">=1.35.0,<1.36.0", optional = true, markers = "extra == \"sqs\""} types-s3transfer = "*" typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} @@ -3387,6 +3391,20 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} +[[package]] +name = "mypy-boto3-dynamodb" +version = "1.35.24" +description = "Type annotations for boto3.DynamoDB 1.35.24 service generated with mypy-boto3-builder 8.1.1" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy_boto3_dynamodb-1.35.24-py3-none-any.whl", hash = "sha256:022859543c5314f14fb03ef4e445e34b97b9bc0cecb003c14c10943a2eaa3ff7"}, + {file = "mypy_boto3_dynamodb-1.35.24.tar.gz", hash = "sha256:55bf897a1d0e354579edb05001f4bc4f472b9452badd9db24876c31bdf3f72a1"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} + [[package]] name = "mypy-boto3-iam" version = "1.35.0" @@ -3401,6 +3419,20 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} +[[package]] +name = "mypy-boto3-iot-data" +version = "1.35.0" +description = "Type annotations for boto3.IoTDataPlane 1.35.0 service generated with mypy-boto3-builder 7.26.0" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy_boto3_iot_data-1.35.0-py3-none-any.whl", hash = "sha256:1f442679a71f22a82b0436ee4f71c06104a9ed722aa71c6800fd93bd345cfc03"}, + {file = "mypy_boto3_iot_data-1.35.0.tar.gz", hash = "sha256:e83cbbd948bc388ed139d2820442af1d319ca37dce708df44295c4acfcfb30f8"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} + [[package]] name = "mypy-boto3-opensearch" version = "1.35.0" @@ -3415,6 +3447,20 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} +[[package]] +name = "mypy-boto3-redshift-data" +version = "1.35.10" +description = "Type annotations for boto3.RedshiftDataAPIService 1.35.10 service generated with mypy-boto3-builder 7.26.1" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy_boto3_redshift_data-1.35.10-py3-none-any.whl", hash = "sha256:1d37d8453c4f3e6b688703a91316729ee2dcaec101326c4f58658d8526d5fc09"}, + {file = "mypy_boto3_redshift_data-1.35.10.tar.gz", hash = "sha256:2cfe518ef3027c2b050facffd2621924458ddf2fb3df9699cdba33e8a6859594"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} + [[package]] name = "mypy-boto3-s3" version = "1.35.2" @@ -3443,6 +3489,20 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} +[[package]] +name = "mypy-boto3-sqs" +version = "1.35.0" +description = "Type annotations for boto3.SQS 1.35.0 service generated with mypy-boto3-builder 7.26.0" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy_boto3_sqs-1.35.0-py3-none-any.whl", hash = "sha256:9fd6e622ed231c06f7542ba6f8f0eea92046cace24defa95d0d0ce04e7caee0c"}, + {file = "mypy_boto3_sqs-1.35.0.tar.gz", hash = "sha256:61752f1c2bf2efa3815f64d43c25b4a39dbdbd9e472ae48aa18d7c6d2a7a6eb8"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.12\""} + [[package]] name = "ndg-httpsclient" version = "0.5.1" @@ -6391,6 +6451,11 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -7024,4 +7089,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "bb4af9c531d0029cb1baeca3a2e94566aaf6d7cb701a6dc07f5e9983bffd1285" +content-hash = "96cb1c9cb807d112d5b6fdae19b99fad98de4f82ab73ae8b24d313dd5d7ff773" diff --git a/pyproject.toml b/pyproject.toml index 591ebc3cd..d5086386e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,7 +228,7 @@ optional = true ruff = "^0.6.0" pyright = "^1.1.376" pre-commit = "^3.7.1" -boto3-stubs = {extras = ["bedrock", "iam", "opensearch", "s3", "sagemaker"], version = "^1.34.105"} +boto3-stubs = {extras = ["bedrock", "iam", "opensearch", "s3", "sagemaker", "sqs", "iot-data", "dynamodb", "redshift-data"], version = "^1.34.105"} typos = "^1.22.9" diff --git a/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py b/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py index 50856c0da..4168f762d 100644 --- a/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_pusher_event_listener_driver.py @@ -9,11 +9,11 @@ class TestPusherEventListenerDriver: @pytest.fixture(autouse=True) def mock_post(self, mocker): - mock_pusher_client = mocker.patch("pusher.Pusher") - mock_pusher_client.return_value.trigger.return_value = Mock() - mock_pusher_client.return_value.trigger_batch.return_value = Mock() + mock_client = mocker.patch("pusher.Pusher") + mock_client.return_value.trigger.return_value = Mock() + mock_client.return_value.trigger_batch.return_value = Mock() - return mock_pusher_client + return mock_client @pytest.fixture() def driver(self): @@ -33,12 +33,12 @@ def test_try_publish_event_payload(self, driver): data = MockEvent().to_dict() driver.try_publish_event_payload(data) - driver.pusher_client.trigger.assert_called_with(channels="test-channel", event_name="test-event", data=data) + driver.client.trigger.assert_called_with(channels="test-channel", event_name="test-event", data=data) def test_try_publish_event_payload_batch(self, driver): data = [MockEvent().to_dict() for _ in range(3)] driver.try_publish_event_payload_batch(data) - driver.pusher_client.trigger_batch.assert_called_with( + driver.client.trigger_batch.assert_called_with( [{"channel": "test-channel", "name": "test-event", "data": data[i]} for i in range(3)] ) diff --git a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py index 05e669b66..dae6f695b 100644 --- a/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_amazon_bedrock_stable_diffusion_image_generation_driver.py @@ -8,13 +8,13 @@ class TestAmazonBedrockImageGenerationDriver: @pytest.fixture() - def bedrock_client(self): + def client(self): return Mock() @pytest.fixture() - def session(self, bedrock_client): + def session(self, client): session = Mock() - session.client.return_value = bedrock_client + session.client.return_value = client return session @@ -40,7 +40,7 @@ def test_init_requires_image_generation_model_driver(self, session): AmazonBedrockImageGenerationDriver(session=session, model="stability.stable-diffusion-xl-v1") # pyright: ignore[reportCallIssue] def test_try_text_to_image(self, driver): - driver.bedrock_client.invoke_model.return_value = { + driver.client.invoke_model.return_value = { "body": io.BytesIO( b"""{ "artifacts": [ diff --git a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py index 9493ab23d..66b23d0c3 100644 --- a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py @@ -9,13 +9,13 @@ class TestAmazonBedrockImageQueryDriver: @pytest.fixture() - def bedrock_client(self, mocker): + def client(self, mocker): return Mock() @pytest.fixture() - def session(self, bedrock_client): + def session(self, client): session = Mock() - session.client.return_value = bedrock_client + session.client.return_value = client return session @@ -35,7 +35,7 @@ def test_init(self, image_query_driver): assert image_query_driver def test_try_query(self, image_query_driver): - image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"""{"content": []}""")} + image_query_driver.client.invoke_model.return_value = {"body": io.BytesIO(b"""{"content": []}""")} text_artifact = image_query_driver.try_query( "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")] @@ -44,7 +44,7 @@ def test_try_query(self, image_query_driver): assert text_artifact.value == "content" def test_try_query_no_body(self, image_query_driver): - image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"")} + image_query_driver.client.invoke_model.return_value = {"body": io.BytesIO(b"")} with pytest.raises(ValueError): image_query_driver.try_query( diff --git a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py index 254a2b3a1..2d824e1c2 100644 --- a/tests/unit/drivers/vector/test_marqo_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_marqo_vector_store_driver.py @@ -86,7 +86,7 @@ def driver(self, mock_marqo): api_key="foobar", url="http://localhost:8000", index="test", - mq=mock_marqo, + client=mock_marqo, embedding_driver=MockEmbeddingDriver(), ) diff --git a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py index 0726a0c7e..a963fb370 100644 --- a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py +++ b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py @@ -7,7 +7,7 @@ class TestPineconeVectorStorageDriver: @pytest.fixture(autouse=True) - def _mock_pinecone(self, mocker): + def mock_index(self, mocker): # Create a fake response fake_query_response = { "matches": [{"id": "foo", "values": [0, 1, 0], "score": 42, "metadata": {"foo": "bar"}}], @@ -15,14 +15,21 @@ def _mock_pinecone(self, mocker): } mock_client = mocker.patch("pinecone.Pinecone") - mock_client().Index().upsert.return_value = None - mock_client().Index().query.return_value = fake_query_response - mock_client().create_index.return_value = None + mock_index = mock_client().Index() + mock_index.upsert.return_value = None + mock_index.query.return_value = fake_query_response + mock_index.create_index.return_value = None + + return mock_index @pytest.fixture() - def driver(self): + def driver(self, mock_index): return PineconeVectorStoreDriver( - api_key="foobar", index_name="test", environment="test", embedding_driver=MockEmbeddingDriver() + api_key="foobar", + index_name="test", + environment="test", + embedding_driver=MockEmbeddingDriver(), + index=mock_index, ) def test_upsert_text_artifact(self, driver): diff --git a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py index ffb359953..3c14f2396 100644 --- a/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_qdrant_vector_store_driver.py @@ -37,30 +37,6 @@ def driver(self, embedding_driver, mocker): embedding_driver=embedding_driver, ) - def test_attrs_post_init(self, driver): - with patch("griptape.drivers.vector.qdrant_vector_store_driver.import_optional_dependency") as mock_import: - mock_qdrant_client = MagicMock() - mock_import.return_value.QdrantClient.return_value = mock_qdrant_client - - driver.__attrs_post_init__() - - mock_import.assert_called_once_with("qdrant_client") - mock_import.return_value.QdrantClient.assert_called_once_with( - location=driver.location, - url=driver.url, - host=driver.host, - path=driver.path, - port=driver.port, - prefer_grpc=driver.prefer_grpc, - grpc_port=driver.grpc_port, - api_key=driver.api_key, - https=driver.https, - prefix=driver.prefix, - force_disable_check_same_thread=driver.force_disable_check_same_thread, - timeout=driver.timeout, - ) - assert driver.client == mock_qdrant_client - def test_delete_vector(self, driver): vector_id = "test_vector_id"