Skip to content

Commit

Permalink
Update client usage
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Sep 18, 2024
1 parent c015885 commit 2858f4e
Show file tree
Hide file tree
Showing 54 changed files with 524 additions and 475 deletions.
21 changes: 21 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver` to `client`.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockTitanEmbeddingDriver` to `client`.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockImageGenerationDriver` to `client`.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockImageQueryDriver` to `client`.
- **BREAKING**: Renamed parameter `bedrock_client` on `AmazonBedrockPromptDriver` to `client`.
- **BREAKING**: Renamed parameter `sagemaker_client` on `AmazonSageMakerJumpstartEmbeddingDriver` to `client`.
- **BREAKING**: Renamed parameter `sagemaker_client` on `AmazonSageMakerJumpstartPromptDriver` to `client`.
- **BREAKING**: Renamed parameter `sqs_client` on `AmazonSqsEventListenerDriver` to `client`.
- **BREAKING**: Renamed parameter `iotdata_client` on `AwsIotCoreEventListenerDriver` to `client`.
- **BREAKING**: Renamed parameter `s3_client` on `AmazonS3FileManagerDriver` to `client`.
- **BREAKING**: Renamed parameter `s3_client` on `AwsS3Tool` to `client`.
- **BREAKING**: Renamed parameter `pusher_client` on `PusherEventListenerDriver` to `client`.
- **BREAKING**: Renamed parameter `mq` on `MarqoVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `model_client` on `GooglePromptDriver` to `client`.
- **BREAKING**: Renamed parameter `model_client` on `GoogleTokenizer` to `client`.
- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `text_generation_pipeline`.
- **BREAKING**: Renamed parameter `engine` on `PgVectorVectorStoreDriver` to `sqlalchemy_engine`.
- Several places where API clients are initialized are now lazy loaded.


## [0.32.0] - 2024-09-17

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from typing import Optional

import openai
from attrs import Factory, define, field
from attrs import define, field

from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.drivers import BaseAudioTranscriptionDriver
from griptape.utils.decorators import lazy_property


@define
Expand All @@ -17,12 +18,11 @@ class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
),
)
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> openai.OpenAI:
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
additional_params = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
Expand All @@ -26,7 +27,7 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
`search_query` when querying your vector DB to find relevant documents.
session: Optionally provide custom `boto3.Session`.
tokenizer: Optionally provide custom `BedrockCohereTokenizer`.
bedrock_client: Optionally provide custom `bedrock-runtime` client.
client: Optionally provide custom `bedrock-runtime` client.
"""

DEFAULT_MODEL = "cohere.embed-english-v3"
Expand All @@ -38,15 +39,16 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
kw_only=True,
)
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Any:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"input_type": self.input_type, "texts": [chunk]}

response = self.bedrock_client.invoke_model(
response = self.client.invoke_model(
body=json.dumps(payload),
modelId=self.model,
accept="*/*",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
Expand All @@ -23,7 +24,7 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
model: Embedding model name. Defaults to DEFAULT_MODEL.
tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
session: Optionally provide custom `boto3.Session`.
bedrock_client: Optionally provide custom `bedrock-runtime` client.
client: Optionally provide custom `bedrock-runtime` client.
"""

DEFAULT_MODEL = "amazon.titan-embed-text-v1"
Expand All @@ -34,15 +35,16 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
kw_only=True,
)
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Any:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"inputText": chunk}

response = self.bedrock_client.invoke_model(
response = self.client.invoke_model(
body=json.dumps(payload),
modelId=self.model,
accept="application/json",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape.drivers import BaseEmbeddingDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
Expand All @@ -15,18 +16,19 @@
@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
sagemaker_client: Any = field(
default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True),
kw_only=True,
)
endpoint: str = field(kw_only=True, metadata={"serializable": True})
custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Any:
return self.session.client("sagemaker-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"text_inputs": chunk, "mode": "embedding"}

endpoint_response = self.sagemaker_client.invoke_endpoint(
endpoint_response = self.client.invoke_endpoint(
EndpointName=self.endpoint,
ContentType="application/json",
Body=json.dumps(payload).encode("utf-8"),
Expand Down
28 changes: 14 additions & 14 deletions griptape/drivers/embedding/azure_openai_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape.drivers import OpenAiEmbeddingDriver
from griptape.tokenizers import OpenAiTokenizer
from griptape.utils.decorators import lazy_property


@define
Expand Down Expand Up @@ -40,17 +41,16 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
client: openai.AzureOpenAI = field(
default=Factory(
lambda self: openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
)
_client: openai.AzureOpenAI = field(default=None, kw_only=True, metadata={"serializable": False})

@lazy_property()
def client(self) -> openai.AzureOpenAI:
return openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
)
11 changes: 6 additions & 5 deletions griptape/drivers/embedding/cohere_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import CohereTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from cohere import Client
Expand All @@ -27,16 +28,16 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver):
DEFAULT_MODEL = "models/embedding-001"

api_key: str = field(kw_only=True, metadata={"serializable": False})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
kw_only=True,
)
input_type: str = field(kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, metadata={"serializable": False})
tokenizer: CohereTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
kw_only=True,
)

input_type: str = field(kw_only=True, metadata={"serializable": True})
@lazy_property()
def client(self) -> Client:
return import_optional_dependency("cohere").Client(self.api_key)

def try_embed_chunk(self, chunk: str) -> list[float]:
result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)
Expand Down
21 changes: 10 additions & 11 deletions griptape/drivers/embedding/huggingface_hub_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from typing import TYPE_CHECKING

from attrs import Factory, define, field
from attrs import define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from huggingface_hub import InferenceClient
Expand All @@ -22,16 +23,14 @@ class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver):
"""

api_token: str = field(kw_only=True, metadata={"serializable": True})
client: InferenceClient = field(
default=Factory(
lambda self: import_optional_dependency("huggingface_hub").InferenceClient(
model=self.model,
token=self.api_token,
),
takes_self=True,
),
kw_only=True,
)
_client: InferenceClient = field(default=None, kw_only=True, metadata={"serializable": False})

@lazy_property()
def client(self) -> InferenceClient:
return import_optional_dependency("huggingface_hub").InferenceClient(
model=self.model,
token=self.api_token,
)

def try_embed_chunk(self, chunk: str) -> list[float]:
response = self.client.feature_extraction(chunk)
Expand Down
12 changes: 7 additions & 5 deletions griptape/drivers/embedding/ollama_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from typing import TYPE_CHECKING, Optional

from attrs import Factory, define, field
from attrs import define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from ollama import Client
Expand All @@ -23,10 +24,11 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver):

model: str = field(kw_only=True, metadata={"serializable": True})
host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True),
kw_only=True,
)
_client: Client = field(default=None, kw_only=True, metadata={"serializable": False})

@lazy_property()
def client(self) -> Client:
return import_optional_dependency("ollama").Client(host=self.host)

def try_embed_chunk(self, chunk: str) -> list[float]:
return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])
12 changes: 6 additions & 6 deletions griptape/drivers/embedding/openai_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import OpenAiTokenizer
from griptape.utils.decorators import lazy_property


@define
Expand All @@ -33,16 +34,15 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
),
)
tokenizer: OpenAiTokenizer = field(
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> openai.OpenAI:
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

def try_embed_chunk(self, chunk: str) -> list[float]:
# Address a performance issue in older ada models
Expand Down
12 changes: 6 additions & 6 deletions griptape/drivers/embedding/voyageai_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import VoyageAiTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property


@define
Expand All @@ -25,17 +26,16 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
client: Any = field(
default=Factory(
lambda self: import_optional_dependency("voyageai").Client(api_key=self.api_key),
takes_self=True,
),
)
tokenizer: VoyageAiTokenizer = field(
default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True),
kw_only=True,
)
input_type: str = field(default="document", kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, metadata={"serializable": False})

@lazy_property()
def client(self) -> Any:
return import_optional_dependency("voyageai").Client(api_key=self.api_key)

def try_embed_chunk(self, chunk: str) -> list[float]:
return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0]
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
Expand All @@ -16,15 +17,19 @@
class AmazonSqsEventListenerDriver(BaseEventListenerDriver):
queue_url: str = field(kw_only=True)
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True))
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Any:
return self.session.client("sqs")

def try_publish_event_payload(self, event_payload: dict) -> None:
self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))
self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
entries = [
{"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
for event_payload in event_payload_batch
]

self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)
self.client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)
Loading

0 comments on commit 2858f4e

Please sign in to comment.