Skip to content

Commit

Permalink
Add DataStax Astra DB vector store driver (#1034)
Browse files Browse the repository at this point in the history
Co-authored-by: Stefano Lottini <[email protected]>
  • Loading branch information
collindutter and hemidactylus authored Jul 31, 2024
1 parent 47212cf commit 3cb1fe7
Show file tree
Hide file tree
Showing 12 changed files with 1,175 additions and 437 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docs-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ jobs:
ZENROWS_API_KEY: ${{ secrets.INTEG_ZENROWS_API_KEY }}
QDRANT_CLUSTER_ENDPOINT: ${{ secrets.INTEG_QDRANT_CLUSTER_ENDPOINT }}
QDRANT_CLUSTER_API_KEY: ${{ secrets.INTEG_QDRANT_CLUSTER_API_KEY }}
ASTRA_DB_API_ENDPOINT: ${{ secrets.INTEG_ASTRA_DB_API_ENDPOINT }}
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.INTEG_ASTRA_DB_APPLICATION_TOKEN }}
services:
postgres:
image: ankane/pgvector:v0.5.0
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- `AstraDbVectorStoreDriver` to support DataStax Astra DB as a vector store.

## [0.29.0] - 2024-07-30

### Added
Expand Down
78 changes: 78 additions & 0 deletions docs/examples/query-webpage-astra-db.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
The following example script ingests a Web page (a blog post),
stores its chunked contents on Astra DB through the Astra DB vector store driver,
and finally runs a RAG process to answer a question specific to the topic of the
Web page.

This script requires that a vector collection has been created in the Astra database
(with name `"griptape_test_collection"` and vector dimension matching the embedding being used, i.e. 1536 in this case).

_Note:_ Besides the [Astra DB](../griptape-framework/drivers/vector-store-drivers.md#astra-db) extra,
this example requires the `drivers-web-scraper-trafilatura`
Griptape extra to be installed as well.


```python
import os

from griptape.drivers import (
AstraDbVectorStoreDriver,
OpenAiChatPromptDriver,
OpenAiEmbeddingDriver,
)
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import (
PromptResponseRagModule,
VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage
from griptape.loaders import WebLoader
from griptape.structures import Agent
from griptape.tools import RagClient, TaskMemoryClient


namespace = "datastax_blog"
input_blogpost = (
"www.datastax.com/blog/indexing-all-of-wikipedia-on-a-laptop"
)

vector_store_driver = AstraDbVectorStoreDriver(
embedding_driver=OpenAiEmbeddingDriver(),
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
collection_name="griptape_test_collection",
astra_db_namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)

engine = RagEngine(
retrieval_stage=RetrievalRagStage(
retrieval_modules=[
VectorStoreRetrievalRagModule(
vector_store_driver=vector_store_driver,
query_params={
"count": 2,
"namespace": namespace,
},
)
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")
)
)
)

vector_store_driver.upsert_text_artifacts(
{namespace: WebLoader(max_tokens=256).load(input_blogpost)}
)

vector_store_tool = RagClient(
description="A DataStax blog post",
rag_engine=engine,
)
agent = Agent(tools=[vector_store_tool, TaskMemoryClient(off_prompt=False)])
agent.run(
"What engine made possible to index such an amount of data, "
"and what kind of tuning was required?"
)
```
46 changes: 46 additions & 0 deletions docs/griptape-framework/drivers/vector-store-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,49 @@ values = [r.to_artifact().value for r in results]

print("\n\n".join(values))
```

### Astra DB

!!! info
This Driver requires the `drivers-vector-astra-db` [extra](../index.md#extras).

The AstraDbVectorStoreDriver supports [DataStax Astra DB](https://www.datastax.com/products/datastax-astra).

The following example shows how to store vector entries and query the information using the driver:

```python
import os
from griptape.drivers import AstraDbVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

# Astra DB secrets and connection parameters
api_endpoint = os.environ["ASTRA_DB_API_ENDPOINT"]
token = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
astra_db_namespace = os.environ.get("ASTRA_DB_KEYSPACE") # optional

# Initialize an Embedding Driver.
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

vector_store_driver = AstraDbVectorStoreDriver(
embedding_driver=embedding_driver,
api_endpoint=api_endpoint,
token=token,
collection_name="griptape_test_collection",
astra_db_namespace=astra_db_namespace, # optional
)

# Load Artifacts from the web
artifacts = WebLoader().load("https://www.griptape.ai")

# Upsert Artifacts into the Vector Store Driver
[
vector_store_driver.upsert_text_artifact(a, namespace="griptape")
for a in artifacts
]

results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))
```
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .vector.azure_mongodb_vector_store_driver import AzureMongoDbVectorStoreDriver
from .vector.dummy_vector_store_driver import DummyVectorStoreDriver
from .vector.qdrant_vector_store_driver import QdrantVectorStoreDriver
from .vector.astradb_vector_store_driver import AstraDbVectorStoreDriver
from .vector.griptape_cloud_knowledge_base_vector_store_driver import GriptapeCloudKnowledgeBaseVectorStoreDriver

from .sql.base_sql_driver import BaseSqlDriver
Expand Down Expand Up @@ -171,6 +172,7 @@
"AmazonOpenSearchVectorStoreDriver",
"PgVectorVectorStoreDriver",
"QdrantVectorStoreDriver",
"AstraDbVectorStoreDriver",
"DummyVectorStoreDriver",
"GriptapeCloudKnowledgeBaseVectorStoreDriver",
"BaseSqlDriver",
Expand Down
184 changes: 184 additions & 0 deletions griptape/drivers/vector/astradb_vector_store_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

from griptape.drivers import BaseVectorStoreDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from astrapy import Collection
from astrapy.authentication import TokenProvider


@define
class AstraDbVectorStoreDriver(BaseVectorStoreDriver):
"""A Vector Store Driver for Astra DB.
Attributes:
embedding_driver: a `griptape.drivers.BaseEmbeddingDriver` for embedding computations within the store
api_endpoint: the "API Endpoint" for the Astra DB instance.
token: a Database Token ("AstraCS:...") secret to access Astra DB. An instance of `astrapy.authentication.TokenProvider` is also accepted.
collection_name: the name of the collection on Astra DB. The collection must have been created beforehand,
and support vectors with a vector dimension matching the embeddings being used by this driver.
environment: the environment ("prod", "hcd", ...) hosting the target Data API.
It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values.
astra_db_namespace: optional specification of the namespace (in the Astra database) for the data.
*Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store.
"""

api_endpoint: str = field(kw_only=True, metadata={"serializable": True})
token: Optional[str | TokenProvider] = field(kw_only=True, default=None, metadata={"serializable": False})
collection_name: str = field(kw_only=True, metadata={"serializable": True})
environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

collection: Collection = field(init=False)

def __attrs_post_init__(self) -> None:
astrapy = import_optional_dependency("astrapy")
self.collection = (
astrapy.DataAPIClient(
caller_name="griptape",
environment=self.environment,
)
.get_database(
self.api_endpoint,
token=self.token,
namespace=self.astra_db_namespace,
)
.get_collection(
name=self.collection_name,
)
)

def delete_vector(self, vector_id: str) -> None:
"""Delete a vector from Astra DB store.
The method succeeds regardless of whether a vector with the provided ID
was actually stored or not in the first place.
Args:
vector_id: ID of the vector to delete.
"""
self.collection.delete_one({"_id": vector_id})

def upsert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
**kwargs: Any,
) -> str:
"""Write a vector to the Astra DB store.
In case the provided ID exists already, an overwrite will take place.
Args:
vector: the vector to be upserted.
vector_id: the ID for the vector to store. If omitted, a server-provided new ID will be employed.
namespace: a namespace (a grouping within the vector store) to assign the vector to.
meta: a metadata dictionary associated to the vector.
kwargs: additional keyword arguments. Currently none is used: if they are passed, they will be ignored with a warning.
Returns:
the ID of the written vector (str).
"""
document = {
k: v
for k, v in {"$vector": vector, "_id": vector_id, "namespace": namespace, "meta": meta}.items()
if v is not None
}
if vector_id is not None:
self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
return vector_id
else:
insert_result = self.collection.insert_one(document)
return insert_result.inserted_id

def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
"""Load a single vector entry from the Astra DB store given its ID.
Args:
vector_id: the ID of the required vector.
namespace: a namespace, within the vector store, to constrain the search.
Returns:
The vector entry (a `BaseVectorStoreDriver.Entry`) if found, otherwise None.
"""
find_filter = {k: v for k, v in {"_id": vector_id, "namespace": namespace}.items() if v is not None}
match = self.collection.find_one(filter=find_filter, projection={"*": 1})
if match is not None:
return BaseVectorStoreDriver.Entry(
id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace")
)
else:
return None

def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
"""Load entries from the Astra DB store.
Args:
namespace: a namespace, within the vector store, to constrain the search.
Returns:
A list of vector (`BaseVectorStoreDriver.Entry`) entries.
"""
find_filter: dict[str, str] = {} if namespace is None else {"namespace": namespace}
return [
BaseVectorStoreDriver.Entry(
id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace")
)
for match in self.collection.find(filter=find_filter, projection={"*": 1})
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
"""Run a similarity search on the Astra DB store, based on a query string.
Args:
query: the query string.
count: the maximum number of results to return. If omitted, defaults will apply.
namespace: the namespace to filter results by.
include_vectors: whether to include vector data in the results.
kwargs: additional keyword arguments. Currently only the free-form dict `filter`
is recognized (and goes straight to the Data API query);
others will generate a warning and be ignored.
Returns:
A list of vector (`BaseVectorStoreDriver.Entry`) entries,
with their `score` attribute set to the vector similarity to the query.
"""
query_filter: Optional[dict[str, Any]] = kwargs.get("filter")
find_filter_ns: dict[str, Any] = {} if namespace is None else {"namespace": namespace}
find_filter = {**(query_filter or {}), **find_filter_ns}
find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
vector = self.embedding_driver.embed_string(query)
ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
matches = self.collection.find(
filter=find_filter,
sort={"$vector": vector},
limit=ann_limit,
projection=find_projection,
include_similarity=True,
)
return [
BaseVectorStoreDriver.Entry(
id=match["_id"],
vector=match.get("$vector"),
score=match["$similarity"],
meta=match.get("meta"),
namespace=match.get("namespace"),
)
for match in matches
]
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,6 @@ nav:
- Load and Query Pinecone: "examples/load-and-query-pinecone.md"
- Load and Query Marqo: "examples/load-query-and-chat-marqo.md"
- Query a Webpage: "examples/query-webpage.md"
- RAG with Astra DB vector store: "examples/query-webpage-astra-db.md"
- Reference Guide: "reference/"
- Trade School: "https://learn.griptape.ai"
Loading

0 comments on commit 3cb1fe7

Please sign in to comment.