Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize client usage #1173

Merged
merged 9 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- Parameter `pipeline_task` on `HuggingFacePipelinePromptDriver` for creating different types of `Pipeline`s.

### Changed
- **BREAKING**: Renamed parameters on several classes to `client`:
  - `bedrock_client` on `AmazonBedrockCohereEmbeddingDriver`.
- `bedrock_client` on `AmazonBedrockTitanEmbeddingDriver`.
- `bedrock_client` on `AmazonBedrockImageGenerationDriver`.
- `bedrock_client` on `AmazonBedrockImageQueryDriver`.
- `bedrock_client` on `AmazonBedrockPromptDriver`.
- `sagemaker_client` on `AmazonSageMakerJumpstartEmbeddingDriver`.
- `sagemaker_client` on `AmazonSageMakerJumpstartPromptDriver`.
- `sqs_client` on `AmazonSqsEventListenerDriver`.
- `iotdata_client` on `AwsIotCoreEventListenerDriver`.
- `s3_client` on `AmazonS3FileManagerDriver`.
- `s3_client` on `AwsS3Tool`.
- `iam_client` on `AwsIamTool`.
- `pusher_client` on `PusherEventListenerDriver`.
- `mq` on `MarqoVectorStoreDriver`.
- `model_client` on `GooglePromptDriver`.
- `model_client` on `GoogleTokenizer`.
- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `pipeline`.
- Several places where API clients are initialized are now lazy loaded.


## [0.32.0] - 2024-09-17

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from typing import Optional

import openai
from attrs import Factory, define, field
from attrs import define, field

from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.drivers import BaseAudioTranscriptionDriver
from griptape.utils.decorators import lazy_property


@define
Expand All @@ -17,12 +18,11 @@
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
),
)
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> openai.OpenAI:
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

Check warning on line 25 in griptape/drivers/audio_transcription/openai_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/openai_audio_transcription_driver.py#L25

Added line #L25 was not covered by tests

def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
additional_params = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
from mypy_boto3_bedrock import BedrockClient

from griptape.tokenizers.base_tokenizer import BaseTokenizer

Expand All @@ -26,7 +28,7 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
`search_query` when querying your vector DB to find relevant documents.
session: Optionally provide custom `boto3.Session`.
tokenizer: Optionally provide custom `BedrockCohereTokenizer`.
bedrock_client: Optionally provide custom `bedrock-runtime` client.
client: Optionally provide custom `bedrock-runtime` client.
"""

DEFAULT_MODEL = "cohere.embed-english-v3"
Expand All @@ -38,15 +40,16 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
kw_only=True,
)
_client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> BedrockClient:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"input_type": self.input_type, "texts": [chunk]}

response = self.bedrock_client.invoke_model(
response = self.client.invoke_model(
body=json.dumps(payload),
modelId=self.model,
accept="*/*",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
from mypy_boto3_bedrock import BedrockClient

from griptape.tokenizers.base_tokenizer import BaseTokenizer

Expand All @@ -23,7 +25,7 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
model: Embedding model name. Defaults to DEFAULT_MODEL.
tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
session: Optionally provide custom `boto3.Session`.
bedrock_client: Optionally provide custom `bedrock-runtime` client.
client: Optionally provide custom `bedrock-runtime` client.
"""

DEFAULT_MODEL = "amazon.titan-embed-text-v1"
Expand All @@ -34,15 +36,16 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
kw_only=True,
)
_client: BedrockClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> BedrockClient:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"inputText": chunk}

response = self.bedrock_client.invoke_model(
response = self.client.invoke_model(
body=json.dumps(payload),
modelId=self.model,
accept="application/json",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,35 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional

from attrs import Factory, define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
import boto3
from mypy_boto3_sagemaker import SageMakerClient


@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
sagemaker_client: Any = field(
default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True),
kw_only=True,
)
endpoint: str = field(kw_only=True, metadata={"serializable": True})
custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
_client: SageMakerClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> SageMakerClient:
return self.session.client("sagemaker-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
payload = {"text_inputs": chunk, "mode": "embedding"}

endpoint_response = self.sagemaker_client.invoke_endpoint(
endpoint_response = self.client.invoke_endpoint(
EndpointName=self.endpoint,
ContentType="application/json",
Body=json.dumps(payload).encode("utf-8"),
Expand Down
28 changes: 14 additions & 14 deletions griptape/drivers/embedding/azure_openai_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape.drivers import OpenAiEmbeddingDriver
from griptape.tokenizers import OpenAiTokenizer
from griptape.utils.decorators import lazy_property


@define
Expand Down Expand Up @@ -40,17 +41,16 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
client: openai.AzureOpenAI = field(
default=Factory(
lambda self: openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
)
_client: openai.AzureOpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> openai.AzureOpenAI:
return openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
)
11 changes: 6 additions & 5 deletions griptape/drivers/embedding/cohere_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import CohereTokenizer
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from cohere import Client
Expand All @@ -27,16 +28,16 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver):
DEFAULT_MODEL = "models/embedding-001"

api_key: str = field(kw_only=True, metadata={"serializable": False})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
kw_only=True,
)
input_type: str = field(kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
tokenizer: CohereTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
kw_only=True,
)

input_type: str = field(kw_only=True, metadata={"serializable": True})
@lazy_property()
def client(self) -> Client:
return import_optional_dependency("cohere").Client(self.api_key)

def try_embed_chunk(self, chunk: str) -> list[float]:
result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)
Expand Down
21 changes: 10 additions & 11 deletions griptape/drivers/embedding/huggingface_hub_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from typing import TYPE_CHECKING

from attrs import Factory, define, field
from attrs import define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from huggingface_hub import InferenceClient
Expand All @@ -22,16 +23,14 @@
"""

api_token: str = field(kw_only=True, metadata={"serializable": True})
client: InferenceClient = field(
default=Factory(
lambda self: import_optional_dependency("huggingface_hub").InferenceClient(
model=self.model,
token=self.api_token,
),
takes_self=True,
),
kw_only=True,
)
_client: InferenceClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> InferenceClient:
return import_optional_dependency("huggingface_hub").InferenceClient(

Check warning on line 30 in griptape/drivers/embedding/huggingface_hub_embedding_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/embedding/huggingface_hub_embedding_driver.py#L30

Added line #L30 was not covered by tests
model=self.model,
token=self.api_token,
)

def try_embed_chunk(self, chunk: str) -> list[float]:
response = self.client.feature_extraction(chunk)
Expand Down
12 changes: 7 additions & 5 deletions griptape/drivers/embedding/ollama_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from typing import TYPE_CHECKING, Optional

from attrs import Factory, define, field
from attrs import define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from ollama import Client
Expand All @@ -23,10 +24,11 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver):

model: str = field(kw_only=True, metadata={"serializable": True})
host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True),
kw_only=True,
)
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Client:
return import_optional_dependency("ollama").Client(host=self.host)

def try_embed_chunk(self, chunk: str) -> list[float]:
return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])
12 changes: 6 additions & 6 deletions griptape/drivers/embedding/openai_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import OpenAiTokenizer
from griptape.utils.decorators import lazy_property


@define
Expand All @@ -33,16 +34,15 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
),
)
tokenizer: OpenAiTokenizer = field(
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> openai.OpenAI:
return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

def try_embed_chunk(self, chunk: str) -> list[float]:
# Address a performance issue in older ada models
Expand Down
Loading
Loading