Commit 12ae245: squash

vachillo committed Sep 18, 2024
1 parent c015885
Showing 53 changed files with 503 additions and 475 deletions.

@@ -4,10 +4,11 @@
 from typing import Optional
 
 import openai
-from attrs import Factory, define, field
+from attrs import define, field
 
 from griptape.artifacts import AudioArtifact, TextArtifact
 from griptape.drivers import BaseAudioTranscriptionDriver
+from griptape.utils.decorators import lazy_property
 
 
 @define
@@ -17,12 +18,11 @@ class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
     base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
     api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
     organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
-    client: openai.OpenAI = field(
-        default=Factory(
-            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
-            takes_self=True,
-        ),
-    )
+    _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> openai.OpenAI:
+        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
 
     def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
         additional_params = {}

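Note on the pattern: every hunk in this commit swaps an eagerly built client field (an attrs Factory with takes_self=True) for a private _client field plus a client method decorated with lazy_property from griptape.utils.decorators, so the underlying SDK client is only constructed on first access and can still be injected by the caller. The decorator itself is not part of this diff; the sketch below is only a rough illustration of how such a decorator could behave, not griptape's actual implementation.

    # Hedged sketch of a lazy, cache-on-first-access property decorator.
    # This is an assumption for illustration only, not the griptape code.
    from __future__ import annotations

    from typing import Any, Callable, Optional


    def lazy_property(attr_name: Optional[str] = None) -> Callable[[Callable[[Any], Any]], property]:
        """Build the value on first access, cache it on the instance, reuse it afterwards."""

        def decorator(func: Callable[[Any], Any]) -> property:
            backing_field = attr_name or f"_{func.__name__}"  # e.g. "client" -> "_client"

            def getter(self: Any) -> Any:
                # If nothing was injected (backing field is still None), build and cache it.
                if getattr(self, backing_field, None) is None:
                    setattr(self, backing_field, func(self))
                return getattr(self, backing_field)

            return property(getter)

        return decorator


    class Example:
        def __init__(self) -> None:
            self._client = None  # backing field the property caches into

        @lazy_property()
        def client(self) -> str:
            print("building client once")
            return "expensive-client"


    e = Example()
    print(e.client)  # builds on first access
    print(e.client)  # served from the cached _client afterwards
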
@@ -8,6 +8,7 @@
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 if TYPE_CHECKING:
     import boto3
@@ -26,7 +27,7 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
             `search_query` when querying your vector DB to find relevant documents.
         session: Optionally provide custom `boto3.Session`.
         tokenizer: Optionally provide custom `BedrockCohereTokenizer`.
-        bedrock_client: Optionally provide custom `bedrock-runtime` client.
+        client: Optionally provide custom `bedrock-runtime` client.
     """
 
     DEFAULT_MODEL = "cohere.embed-english-v3"
@@ -38,15 +39,16 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
         default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
         kw_only=True,
     )
-    bedrock_client: Any = field(
-        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
-        kw_only=True,
-    )
+    _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> Any:
+        return self.session.client("bedrock-runtime")
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         payload = {"input_type": self.input_type, "texts": [chunk]}
 
-        response = self.bedrock_client.invoke_model(
+        response = self.client.invoke_model(
             body=json.dumps(payload),
             modelId=self.model,
             accept="*/*",

@@ -8,6 +8,7 @@
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 if TYPE_CHECKING:
     import boto3
@@ -23,7 +24,7 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
         model: Embedding model name. Defaults to DEFAULT_MODEL.
         tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
         session: Optionally provide custom `boto3.Session`.
-        bedrock_client: Optionally provide custom `bedrock-runtime` client.
+        client: Optionally provide custom `bedrock-runtime` client.
     """
 
     DEFAULT_MODEL = "amazon.titan-embed-text-v1"
@@ -34,15 +35,16 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
         default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
         kw_only=True,
    )
-    bedrock_client: Any = field(
-        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
-        kw_only=True,
-    )
+    _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> Any:
+        return self.session.client("bedrock-runtime")
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         payload = {"inputText": chunk}
 
-        response = self.bedrock_client.invoke_model(
+        response = self.client.invoke_model(
             body=json.dumps(payload),
             modelId=self.model,
             accept="application/json",

@@ -7,6 +7,7 @@
 
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 if TYPE_CHECKING:
     import boto3
@@ -15,18 +16,19 @@
 @define
 class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
     session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
-    sagemaker_client: Any = field(
-        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True),
-        kw_only=True,
-    )
     endpoint: str = field(kw_only=True, metadata={"serializable": True})
     custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
     inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
+    _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> Any:
+        return self.session.client("sagemaker-runtime")
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         payload = {"text_inputs": chunk, "mode": "embedding"}
 
-        endpoint_response = self.sagemaker_client.invoke_endpoint(
+        endpoint_response = self.client.invoke_endpoint(
             EndpointName=self.endpoint,
             ContentType="application/json",
             Body=json.dumps(payload).encode("utf-8"),

28 changes: 14 additions & 14 deletions griptape/drivers/embedding/azure_openai_embedding_driver.py
@@ -7,6 +7,7 @@
 
 from griptape.drivers import OpenAiEmbeddingDriver
 from griptape.tokenizers import OpenAiTokenizer
+from griptape.utils.decorators import lazy_property
 
 
 @define
@@ -40,17 +41,16 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
         default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
         kw_only=True,
    )
-    client: openai.AzureOpenAI = field(
-        default=Factory(
-            lambda self: openai.AzureOpenAI(
-                organization=self.organization,
-                api_key=self.api_key,
-                api_version=self.api_version,
-                azure_endpoint=self.azure_endpoint,
-                azure_deployment=self.azure_deployment,
-                azure_ad_token=self.azure_ad_token,
-                azure_ad_token_provider=self.azure_ad_token_provider,
-            ),
-            takes_self=True,
-        ),
-    )
+    _client: openai.AzureOpenAI = field(default=None, kw_only=True, metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> openai.AzureOpenAI:
+        return openai.AzureOpenAI(
+            organization=self.organization,
+            api_key=self.api_key,
+            api_version=self.api_version,
+            azure_endpoint=self.azure_endpoint,
+            azure_deployment=self.azure_deployment,
+            azure_ad_token=self.azure_ad_token,
+            azure_ad_token_provider=self.azure_ad_token_provider,
+        )

11 changes: 6 additions & 5 deletions griptape/drivers/embedding/cohere_embedding_driver.py
@@ -7,6 +7,7 @@
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.tokenizers import CohereTokenizer
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 if TYPE_CHECKING:
     from cohere import Client
@@ -27,16 +28,16 @@ class CohereEmbeddingDriver(BaseEmbeddingDriver):
     DEFAULT_MODEL = "models/embedding-001"
 
     api_key: str = field(kw_only=True, metadata={"serializable": False})
-    client: Client = field(
-        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
-        kw_only=True,
-    )
+    input_type: str = field(kw_only=True, metadata={"serializable": True})
+    _client: Client = field(default=None, kw_only=True, metadata={"serializable": False})
     tokenizer: CohereTokenizer = field(
         default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
         kw_only=True,
    )
 
-    input_type: str = field(kw_only=True, metadata={"serializable": True})
+    @lazy_property()
+    def client(self) -> Client:
+        return import_optional_dependency("cohere").Client(self.api_key)
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)

21 changes: 10 additions & 11 deletions griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@@ -2,10 +2,11 @@
 
 from typing import TYPE_CHECKING
 
-from attrs import Factory, define, field
+from attrs import define, field
 
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 if TYPE_CHECKING:
     from huggingface_hub import InferenceClient
@@ -22,16 +23,14 @@ class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver):
     """
 
     api_token: str = field(kw_only=True, metadata={"serializable": True})
-    client: InferenceClient = field(
-        default=Factory(
-            lambda self: import_optional_dependency("huggingface_hub").InferenceClient(
-                model=self.model,
-                token=self.api_token,
-            ),
-            takes_self=True,
-        ),
-        kw_only=True,
-    )
+    _client: InferenceClient = field(default=None, kw_only=True, metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> InferenceClient:
+        return import_optional_dependency("huggingface_hub").InferenceClient(
+            model=self.model,
+            token=self.api_token,
+        )
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         response = self.client.feature_extraction(chunk)

12 changes: 7 additions & 5 deletions griptape/drivers/embedding/ollama_embedding_driver.py
@@ -2,10 +2,11 @@
 
 from typing import TYPE_CHECKING, Optional
 
-from attrs import Factory, define, field
+from attrs import define, field
 
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 if TYPE_CHECKING:
     from ollama import Client
@@ -23,10 +24,11 @@ class OllamaEmbeddingDriver(BaseEmbeddingDriver):
 
     model: str = field(kw_only=True, metadata={"serializable": True})
     host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
-    client: Client = field(
-        default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True),
-        kw_only=True,
-    )
+    _client: Client = field(default=None, kw_only=True, metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> Client:
+        return import_optional_dependency("ollama").Client(host=self.host)
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])

12 changes: 6 additions & 6 deletions griptape/drivers/embedding/openai_embedding_driver.py
@@ -7,6 +7,7 @@
 
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.tokenizers import OpenAiTokenizer
+from griptape.utils.decorators import lazy_property
 
 
 @define
@@ -33,16 +34,15 @@ class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
     base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
     api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
     organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
-    client: openai.OpenAI = field(
-        default=Factory(
-            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
-            takes_self=True,
-        ),
-    )
     tokenizer: OpenAiTokenizer = field(
         default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
         kw_only=True,
    )
+    _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> openai.OpenAI:
+        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         # Address a performance issue in older ada models

12 changes: 6 additions & 6 deletions griptape/drivers/embedding/voyageai_embedding_driver.py
@@ -7,6 +7,7 @@
 from griptape.drivers import BaseEmbeddingDriver
 from griptape.tokenizers import VoyageAiTokenizer
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 
 @define
@@ -25,17 +26,16 @@ class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):
 
     model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
     api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
-    client: Any = field(
-        default=Factory(
-            lambda self: import_optional_dependency("voyageai").Client(api_key=self.api_key),
-            takes_self=True,
-        ),
-    )
     tokenizer: VoyageAiTokenizer = field(
         default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True),
         kw_only=True,
    )
     input_type: str = field(default="document", kw_only=True, metadata={"serializable": True})
+    _client: Any = field(default=None, kw_only=True, metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> Any:
+        return import_optional_dependency("voyageai").Client(api_key=self.api_key)
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
         return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0]

@@ -7,6 +7,7 @@
 
 from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 if TYPE_CHECKING:
     import boto3
@@ -16,15 +17,19 @@
 class AmazonSqsEventListenerDriver(BaseEventListenerDriver):
     queue_url: str = field(kw_only=True)
     session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
-    sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True))
+    _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> Any:
+        return self.session.client("sqs")
 
     def try_publish_event_payload(self, event_payload: dict) -> None:
-        self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))
+        self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))
 
     def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
         entries = [
             {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
             for event_payload in event_payload_batch
         ]
 
-        self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)
+        self.client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)

@@ -7,6 +7,7 @@
 
 from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver
 from griptape.utils import import_optional_dependency
+from griptape.utils.decorators import lazy_property
 
 if TYPE_CHECKING:
     import boto3
@@ -17,10 +18,14 @@ class AwsIotCoreEventListenerDriver(BaseEventListenerDriver):
     iot_endpoint: str = field(kw_only=True)
     topic: str = field(kw_only=True)
     session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
-    iotdata_client: Any = field(default=Factory(lambda self: self.session.client("iot-data"), takes_self=True))
+    _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
+
+    @lazy_property()
+    def client(self) -> Any:
+        return self.session.client("iot-data")
 
     def try_publish_event_payload(self, event_payload: dict) -> None:
-        self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload))
+        self.client.publish(topic=self.topic, payload=json.dumps(event_payload))
 
     def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
-        self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch))
+        self.client.publish(topic=self.topic, payload=json.dumps(event_payload_batch))

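Usage note: for drivers whose new _client field declares alias="client" (the OpenAI, Amazon Bedrock, SageMaker, SQS, and IoT Core drivers above), a pre-built client that used to be passed via a service-specific keyword such as bedrock_client, sqs_client, or iotdata_client is now passed as client; if nothing is passed, the lazy_property builds one from session on first access. A minimal sketch of that call pattern, with an illustrative region and an import path assumed from the module names in this diff:

    # Hedged usage sketch; the region value is illustrative only.
    import boto3

    from griptape.drivers import AmazonBedrockTitanEmbeddingDriver

    session = boto3.Session(region_name="us-east-1")

    # Old keyword (removed in this commit):
    # driver = AmazonBedrockTitanEmbeddingDriver(bedrock_client=session.client("bedrock-runtime"))

    # New keyword, exposed through the attrs alias="client" on the _client field:
    driver = AmazonBedrockTitanEmbeddingDriver(client=session.client("bedrock-runtime"))

    # Or let the driver build the client lazily from the session on first use:
    lazy_driver = AmazonBedrockTitanEmbeddingDriver(session=session)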