Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Sep 16, 2024
1 parent 8c407de commit 6e81d41
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 83 deletions.
9 changes: 3 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Renamed parameter `s3_client` on `AmazonS3FileManagerDriver` to `client`.
- **BREAKING**: Renamed parameter `s3_client` on `AwsS3Tool` to `client`.
- **BREAKING**: Renamed parameter `pusher_client` on `PusherEventListenerDriver` to `client`.
- **BREAKING**: Renamed parameter `model_client` on `GooglePromptDriver` to `client`.
- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `client`.
- **BREAKING**: Renamed parameter `collection` on `AstraDbVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `mq` on `MarqoVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `engine` on `PgVectorVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `index` on `PineconeVectorStoreDriver` to `client`.
- **BREAKING**: Renamed parameter `model_client` on `GooglePromptDriver` to `client`.
- **BREAKING**: Renamed parameter `model_client` on `GoogleTokenizer` to `client`.
- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `text_generation_pipeline`.
- Updated `JsonArtifact` value converter to properly handle more types.
- `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
- `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`.
- Removed `__add__` method from `BaseArtifact`, implemented it where necessary.
- Generic type support to `ListArtifact`.
- Iteration support to `ListArtifact`.
- The `client` parameter on `Driver`s that use a client is now lazily initialized.
- Several places where API clients are initialized are now lazily loaded.

## [0.31.0] - 2024-09-03

Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver):
),
kw_only=True,
)
# Private backing field for the lazily-built transformers pipeline; callers
# supply it via the `text_generation_pipeline` alias, otherwise it is created
# on first access by the `text_generation_pipeline` lazy property.
# (The fused pre-rename `_client`/alias="client" lines from the diff are removed.)
_text_generation_pipeline: TextGenerationPipeline = field(
    default=None, kw_only=True, alias="text_generation_pipeline", metadata={"serializable": False}
)

@lazy_property()
def client(self) -> TextGenerationPipeline:
def text_generation_pipeline(self) -> TextGenerationPipeline:
return import_optional_dependency("transformers").pipeline(
"text-generation",
model=self.model,
Expand All @@ -53,7 +53,7 @@ def client(self) -> TextGenerationPipeline:
def try_run(self, prompt_stack: PromptStack) -> Message:
messages = self._prompt_stack_to_messages(prompt_stack)

result = self.client(
result = self.text_generation_pipeline(
messages,
max_new_tokens=self.max_tokens,
temperature=self.temperature,
Expand Down
52 changes: 27 additions & 25 deletions griptape/drivers/vector/astradb_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from astrapy import Collection
from astrapy.authentication import TokenProvider
import astrapy
import astrapy.authentication


@define
Expand All @@ -27,33 +27,35 @@ class AstraDbVectorStoreDriver(BaseVectorStoreDriver):
It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values.
astra_db_namespace: optional specification of the namespace (in the Astra database) for the data.
*Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store.
caller_name: the name of the caller for the Astra DB client. Defaults to "griptape".
client: an instance of `astrapy.DataAPIClient` for the Astra DB.
collection: an instance of `astrapy.Collection` for the Astra DB.
"""

# Astra DB API endpoint for the target database.
api_endpoint: str = field(kw_only=True, metadata={"serializable": True})
# Access token (plain string or astrapy TokenProvider); never serialized.
token: Optional[str | astrapy.authentication.TokenProvider] = field(
    kw_only=True, default=None, metadata={"serializable": False}
)
# Name of the collection holding the vectors.
collection_name: str = field(kw_only=True, metadata={"serializable": True})
# Optional Astra environment (see astrapy.constants.Environment); None means production.
environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
# Optional Astra database namespace (distinct from this driver's vector "namespace" grouping).
astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
# Caller name reported to the Astra DB client.
caller_name: str = field(default="griptape", kw_only=True, metadata={"serializable": False})
# Private backing fields for the lazily-initialized client and collection.
_client: astrapy.DataAPIClient = field(default=None, kw_only=True, metadata={"serializable": False})
_collection: astrapy.Collection = field(default=None, kw_only=True, metadata={"serializable": False})

@lazy_property()
def client(self) -> astrapy.DataAPIClient:
    """Return the lazily-initialized Astra DB Data API client.

    The `astrapy` package is imported only on first access; the client is
    configured from `caller_name` and `environment`. Database/collection
    resolution is handled separately by the `collection` property.
    (The fused pre-commit Collection-returning body from the diff is removed.)
    """
    return import_optional_dependency("astrapy").DataAPIClient(
        caller_name=self.caller_name,
        environment=self.environment,
    )

@lazy_property()
def collection(self) -> astrapy.Collection:
    """Return the lazily-resolved Astra DB collection backing this driver."""
    database = self.client.get_database(
        self.api_endpoint,
        token=self.token,
        namespace=self.astra_db_namespace,
    )
    return database.get_collection(self.collection_name)

def delete_vector(self, vector_id: str) -> None:
"""Delete a vector from Astra DB store.
Expand All @@ -63,7 +65,7 @@ def delete_vector(self, vector_id: str) -> None:
Args:
vector_id: ID of the vector to delete.
"""
self.client.delete_one({"_id": vector_id})
self.collection.delete_one({"_id": vector_id})

def upsert_vector(
self,
Expand Down Expand Up @@ -94,10 +96,10 @@ def upsert_vector(
if v is not None
}
if vector_id is not None:
self.client.find_one_and_replace({"_id": vector_id}, document, upsert=True)
self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
return vector_id
else:
insert_result = self.client.insert_one(document)
insert_result = self.collection.insert_one(document)
return insert_result.inserted_id

def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
Expand All @@ -111,7 +113,7 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
The vector entry (a `BaseVectorStoreDriver.Entry`) if found, otherwise None.
"""
find_filter = {k: v for k, v in {"_id": vector_id, "namespace": namespace}.items() if v is not None}
match = self.client.find_one(filter=find_filter, projection={"*": 1})
match = self.collection.find_one(filter=find_filter, projection={"*": 1})
if match is not None:
return BaseVectorStoreDriver.Entry(
id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace")
Expand All @@ -133,7 +135,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
BaseVectorStoreDriver.Entry(
id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace")
)
for match in self.client.find(filter=find_filter, projection={"*": 1})
for match in self.collection.find(filter=find_filter, projection={"*": 1})
]

def query(
Expand Down Expand Up @@ -166,7 +168,7 @@ def query(
find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
vector = self.embedding_driver.embed_string(query)
ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
matches = self.client.find(
matches = self.collection.find(
filter=find_filter,
sort={"$vector": vector},
limit=ann_limit,
Expand Down
22 changes: 11 additions & 11 deletions griptape/drivers/vector/pgvector_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from sqlalchemy.engine import Engine
import sqlalchemy


@define
Expand All @@ -30,12 +30,12 @@ class PgVectorVectorStoreDriver(BaseVectorStoreDriver):
create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
table_name: str = field(kw_only=True, metadata={"serializable": True})
_model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))
# Private backing field for the lazily-created SQLAlchemy engine, exposed via
# the `sqlalchemy_engine` lazy property. (The stale pre-rename
# `_client: Engine` line fused into the diff is removed.)
_sqlalchemy_engine: sqlalchemy.Engine = field(default=None, kw_only=True, metadata={"serializable": False})

@connection_string.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_connection_string(self, _: Attribute, connection_string: Optional[str]) -> None:
# If an engine is provided, the connection string is not used.
if self._client is not None:
if self._sqlalchemy_engine is not None:
return

# If an engine is not provided, a connection string is required.
Expand All @@ -46,7 +46,7 @@ def validate_connection_string(self, _: Attribute, connection_string: Optional[s
raise ValueError("The connection string must describe a Postgres database connection")

@lazy_property()
def sqlalchemy_engine(self) -> sqlalchemy.Engine:
    """Return the lazily-created SQLAlchemy engine for `connection_string`.

    `sqlalchemy` is imported only on first access; extra engine options come
    from `create_engine_params`. (The fused pre-rename `def client` header
    from the diff is removed.)
    """
    return import_optional_dependency("sqlalchemy").create_engine(
        self.connection_string, **self.create_engine_params
    )
Expand All @@ -62,15 +62,15 @@ def setup(
sqlalchemy_sql = import_optional_dependency("sqlalchemy.sql")

if install_uuid_extension:
with self.client.begin() as conn:
with self.sqlalchemy_engine.begin() as conn:
conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'))

if install_vector_extension:
with self.client.begin() as conn:
with self.sqlalchemy_engine.begin() as conn:
conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "vector";'))

if create_schema:
self._model.metadata.create_all(self.client)
self._model.metadata.create_all(self.sqlalchemy_engine)

def upsert_vector(
self,
Expand All @@ -84,7 +84,7 @@ def upsert_vector(
"""Inserts or updates a vector in the collection."""
sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

with sqlalchemy_orm.Session(self.client) as session:
with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session:
obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

obj = session.merge(obj)
Expand All @@ -96,7 +96,7 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Base
"""Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

with sqlalchemy_orm.Session(self.client) as session:
with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session:
result = session.get(self._model, vector_id)

return BaseVectorStoreDriver.Entry(
Expand All @@ -110,7 +110,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
"""Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace."""
sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

with sqlalchemy_orm.Session(self.client) as session:
with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session:
query = session.query(self._model)
if namespace:
query = query.filter_by(namespace=namespace)
Expand Down Expand Up @@ -151,7 +151,7 @@ def query(

op = distance_metrics[distance_metric]

with sqlalchemy_orm.Session(self.client) as session:
with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session:
vector = self.embedding_driver.embed_string(query)

# The query should return both the vector and the distance metric score.
Expand Down
29 changes: 15 additions & 14 deletions griptape/drivers/vector/pinecone_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@ class PineconeVectorStoreDriver(BaseVectorStoreDriver):
index_name: str = field(kw_only=True, metadata={"serializable": True})
environment: str = field(kw_only=True, metadata={"serializable": True})
project_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
# Private backing fields for the lazily-initialized Pinecone API client and
# index handle, supplied via the `client` / `index` aliases or built on first
# access. (The stale pre-commit `_client: pinecone.Index` line fused into the
# diff is removed.)
_client: pinecone.Pinecone = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
_index: pinecone.Index = field(default=None, kw_only=True, alias="index", metadata={"serializable": False})

@lazy_property()
def client(self) -> pinecone.Pinecone:
    """Return the lazily-initialized Pinecone API client.

    The `pinecone` package is imported only on first access. Index resolution
    is handled separately by the `index` property. (The fused pre-commit
    Index-returning body from the diff is removed.)
    """
    return import_optional_dependency("pinecone").Pinecone(
        api_key=self.api_key,
        environment=self.environment,
        project_name=self.project_name,
    )

@lazy_property()
def index(self) -> pinecone.Index:
    """Return the lazily-resolved handle to the configured Pinecone index.

    NOTE(review): the Pinecone Python SDK exposes `Pinecone.Index(name)` to
    obtain an index handle; there is no `get_index` method, so the original
    `self.client.get_index(...)` would raise AttributeError on first access —
    confirm against the pinned pinecone-client version.
    """
    return self.client.Index(self.index_name)

def upsert_vector(
self,
vector: list[float],
Expand All @@ -44,12 +45,12 @@ def upsert_vector(

params: dict[str, Any] = {"namespace": namespace} | kwargs

self.client.upsert(vectors=[(vector_id, vector, meta)], **params)
self.index.upsert(vectors=[(vector_id, vector, meta)], **params)

return vector_id

def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
result = self.client.fetch(ids=[vector_id], namespace=namespace).to_dict()
result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()
vectors = list(result["vectors"].values())

if len(vectors) > 0:
Expand All @@ -69,7 +70,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
# all values from a namespace:
# https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5

results = self.client.query(
results = self.index.query(
vector=self.embedding_driver.embed_string(""),
top_k=10000,
include_metadata=True,
Expand Down Expand Up @@ -105,7 +106,7 @@ def query(
"include_metadata": include_metadata,
} | kwargs

results = self.client.query(vector=vector, **params)
results = self.index.query(vector=vector, **params)

return [
BaseVectorStoreDriver.Entry(
Expand Down
Loading

0 comments on commit 6e81d41

Please sign in to comment.