Update Astra DB vector store to use modern astrapy library (#14407)

hemidactylus authored Jun 26, 2024
1 parent 48d22f4 commit 831947a

Showing 3 changed files with 205 additions and 98 deletions.
@@ -1,19 +1,23 @@
"""
Astra DB Vector store index.
Astra DB Vector Store index.
An index based on a DB table with vector search capabilities,
powered by the astrapy library
An index based on a DB collection with vector search capabilities,
powered by the AstraPy library
"""

 import json
 import logging
 from typing import Any, Dict, List, Optional, cast
+from concurrent.futures import ThreadPoolExecutor
 from warnings import warn
 
+from astrapy import DataAPIClient
+from astrapy.exceptions import InsertManyException
+from astrapy.results import UpdateResult
+
+import llama_index.core
 from llama_index.core.bridge.pydantic import PrivateAttr
-from astrapy.db import AstraDB
 from llama_index.core.indices.query.embedding_utils import get_top_k_mmr_embeddings
 from llama_index.core.schema import BaseNode, MetadataMode
 from llama_index.core.vector_stores.types import (
@@ -34,7 +38,8 @@
 _logger = logging.getLogger(__name__)
 
 DEFAULT_MMR_PREFETCH_FACTOR = 4.0
-MAX_INSERT_BATCH_SIZE = 20
+
+REPLACE_DOCUMENTS_MAX_THREADS = 12
 
 NON_INDEXED_FIELDS = ["metadata._node_content", "content"]
 
@@ -43,24 +48,21 @@ class AstraDBVectorStore(BasePydanticVectorStore):
"""
Astra DB Vector Store.
An abstraction of a Astra table with
An abstraction of a Astra DB collection with
vector-similarity-search. Documents, and their embeddings, are stored
in an Astra table and a vector-capable index is used for searches.
The table does not need to exist beforehand: if necessary it will
be created behind the scenes.
in an Astra DB collection equipped with a vector index.
The collection, if necessary, is created when the vector store is initialized.
All Astra operations are done through the astrapy library.
All Astra operations are done through the AstraPy library.
Visit https://astra.datastax.com/signup to create an account and get an API key.
Visit https://astra.datastax.com/signup to create an account and get started.
Args:
collection_name (str): collection name to use. If not existing, it will be created.
token (str): The Astra DB Application Token to use.
api_endpoint (str): The Astra DB JSON API endpoint for your database.
embedding_dimension (int): length of the embedding vectors in use.
namespace (Optional[str]): The namespace to use. If not provided, 'default_keyspace'
ttl_seconds (Optional[int]): expiration time for inserted entries.
Default is no expiration.
Examples:
`pip install llama-index-vector-stores-astra`
@@ -70,7 +72,7 @@ class AstraDBVectorStore(BasePydanticVectorStore):
         # Create the Astra DB Vector Store object
         astra_db_store = AstraDBVectorStore(
-            collection_name="astra_v_table",
+            collection_name="astra_v_store",
             token=token,
             api_endpoint=api_endpoint,
             embedding_dimension=1536,
@@ -83,9 +85,8 @@ class AstraDBVectorStore(BasePydanticVectorStore):
     flat_metadata: bool = True
 
     _embedding_dimension: int = PrivateAttr()
-    _ttl_seconds: Optional[int] = PrivateAttr()
-    _astra_db: Any = PrivateAttr()
-    _astra_db_collection: Any = PrivateAttr()
+    _database: Any = PrivateAttr()
+    _collection: Any = PrivateAttr()
 
     def __init__(
         self,
@@ -101,79 +102,92 @@ def __init__(

         # Set all the required class parameters
         self._embedding_dimension = embedding_dimension
-        self._ttl_seconds = ttl_seconds
 
-        _logger.debug("Creating the Astra DB table")
+        if ttl_seconds is not None:
+            warn(
+                (
+                    "Parameter `ttl_seconds` is not supported for "
+                    "`AstraDBVectorStore` and will be ignored."
+                ),
+                UserWarning,
+                stacklevel=2,
+            )
+
+        _logger.debug("Creating the Astra DB client and database instances")
 
-        # Build the Astra DB object
-        self._astra_db = AstraDB(
-            api_endpoint=api_endpoint,
-            token=token,
-            namespace=namespace,
+        # Build the Database object
+        self._database = DataAPIClient(
+            caller_name=getattr(llama_index, "__name__", "llama_index"),
+            caller_version=getattr(llama_index.core, "__version__", None),
+        ).get_database(
+            api_endpoint,
+            token=token,
+            namespace=namespace,
         )
 
-        from astrapy.api import APIRequestError
+        from astrapy.exceptions import DataAPIException
+
+        collection_indexing = {"deny": NON_INDEXED_FIELDS}
 
         try:
             _logger.debug("Creating the Astra DB collection")
             # Create and connect to the newly created collection
-            self._astra_db_collection = self._astra_db.create_collection(
-                collection_name=collection_name,
+            self._collection = self._database.create_collection(
+                name=collection_name,
                 dimension=embedding_dimension,
-                options={"indexing": {"deny": NON_INDEXED_FIELDS}},
+                indexing=collection_indexing,
+                check_exists=False,
             )
-        except APIRequestError:
+        except DataAPIException as e:
             # possibly the collection is preexisting and has legacy
             # indexing settings: verify
-            get_coll_response = self._astra_db.get_collections(
-                options={"explain": True}
-            )
-            collections = (get_coll_response["status"] or {}).get("collections") or []
             preexisting = [
-                collection
-                for collection in collections
-                if collection["name"] == collection_name
+                coll_descriptor
+                for coll_descriptor in self._database.list_collections()
+                if coll_descriptor.name == collection_name
             ]
             if preexisting:
-                pre_collection = preexisting[0]
                 # if it has no "indexing", it is a legacy collection;
-                # otherwise it's unexpected warn and proceed at user's risk
-                pre_col_options = pre_collection.get("options") or {}
-                if "indexing" not in pre_col_options:
+                # otherwise it's unexpected: warn and proceed at user's risk
+                pre_col_idx_opts = preexisting[0].options.indexing or {}
+                if not pre_col_idx_opts:
                     warn(
                         (
                             f"Collection '{collection_name}' is detected as "
                             "having indexing turned on for all fields "
                             "(either created manually or by older versions "
                            "of this plugin). This implies stricter "
                             "limitations on the amount of text"
-                            " each entry can store. Consider reindexing anew on a"
+                            " each entry can store. Consider indexing anew on a"
                             " fresh collection to be able to store longer texts."
                         ),
                         UserWarning,
                         stacklevel=2,
                     )
-                    self._astra_db_collection = self._astra_db.collection(
-                        collection_name=collection_name,
+                    self._collection = self._database.get_collection(
+                        collection_name,
                     )
                 else:
-                    options_json = json.dumps(pre_col_options["indexing"])
-                    warn(
-                        (
-                            f"Collection '{collection_name}' has unexpected 'indexing'"
-                            f" settings (options.indexing = {options_json})."
-                            " This can result in odd behaviour when running "
-                            " metadata filtering and/or unwarranted limitations"
-                            " on storing long texts. Consider reindexing anew on a"
-                            " fresh collection."
-                        ),
-                        UserWarning,
-                        stacklevel=2,
-                    )
-                    self._astra_db_collection = self._astra_db.collection(
-                        collection_name=collection_name,
-                    )
+                    # check if the indexing options match entirely
+                    if pre_col_idx_opts == collection_indexing:
+                        raise
+                    else:
+                        options_json = json.dumps(pre_col_idx_opts)
+                        warn(
+                            (
+                                f"Collection '{collection_name}' has unexpected 'indexing'"
+                                f" settings (options.indexing = {options_json})."
+                                " This can result in odd behaviour when running "
+                                " metadata filtering and/or unwarranted limitations"
+                                " on storing long texts. Consider indexing anew on a"
+                                " fresh collection."
+                            ),
+                            UserWarning,
+                            stacklevel=2,
+                        )
+                        self._collection = self._database.get_collection(
+                            collection_name,
+                        )
             else:
                 # other exception
                 raise
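
For orientation, the construction flow introduced above can be exercised standalone. A minimal sketch, assuming astrapy 1.x; the endpoint, token, caller name and collection name below are placeholder values, not part of this commit:

```python
from astrapy import DataAPIClient

# Placeholder credentials -- substitute your own database's values.
API_ENDPOINT = "https://<db-id>-<region>.apps.astra.datastax.com"
TOKEN = "AstraCS:..."

# One client per application; get_database binds it to a concrete database.
client = DataAPIClient(caller_name="my_app", caller_version="0.1.0")
database = client.get_database(API_ENDPOINT, token=TOKEN)

# With check_exists=False the call succeeds if the collection is absent and
# surfaces a DataAPIException on conflicting settings, which is exactly the
# case the except-branch above inspects.
collection = database.create_collection(
    name="astra_v_store",
    dimension=1536,
    indexing={"deny": ["metadata._node_content", "content"]},
    check_exists=False,
)
```
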
@@ -190,8 +204,8 @@ def add(
             nodes: List[BaseNode]: list of node with embeddings
 
         """
-        # Initialize list of objects to track
-        nodes_list = []
+        # Initialize list of documents to insert
+        documents_to_insert: List[Dict[str, Any]] = []
 
         # Process each node individually
         for node in nodes:
@@ -203,7 +217,7 @@
             )
 
             # One dictionary of node data per node
-            nodes_list.append(
+            documents_to_insert.append(
                 {
                     "_id": node.node_id,
                     "content": node.get_content(metadata_mode=MetadataMode.NONE),
@@ -212,26 +226,61 @@
                 }
             )
 
-        # Log the number of rows being added
-        _logger.debug(f"Adding {len(nodes_list)} rows to table")
+        # Log the number of documents being added
+        _logger.debug(f"Adding {len(documents_to_insert)} documents to the collection")
 
-        # Initialize an empty list to hold the batches
-        batched_list = []
+        # perform an AstraPy insert_many, catching exceptions for overwriting docs
+        ids_to_replace: List[int]
+        try:
+            self._collection.insert_many(
+                documents_to_insert,
+                ordered=False,
+            )
+            ids_to_replace = []
+        except InsertManyException as err:
+            inserted_ids_set = set(err.partial_result.inserted_ids)
+            ids_to_replace = [
+                document["_id"]
+                for document in documents_to_insert
+                if document["_id"] not in inserted_ids_set
+            ]
+            _logger.debug(
+                f"Detected {len(ids_to_replace)} non-inserted documents, trying replace_one"
+            )
 
-        # Iterate over the node_list in steps of MAX_INSERT_BATCH_SIZE
-        for i in range(0, len(nodes_list), MAX_INSERT_BATCH_SIZE):
-            # Append a slice of node_list to the batched_list
-            batched_list.append(nodes_list[i : i + MAX_INSERT_BATCH_SIZE])
+        # if necessary, replace docs for the non-inserted ids
+        if ids_to_replace:
+            documents_to_replace = [
+                document
+                for document in documents_to_insert
+                if document["_id"] in ids_to_replace
+            ]
 
-        # Perform the bulk insert
-        for i, batch in enumerate(batched_list):
-            _logger.debug(f"Processing batch #{i + 1} of size {len(batch)}")
+            with ThreadPoolExecutor(
+                max_workers=REPLACE_DOCUMENTS_MAX_THREADS
+            ) as executor:
 
-            # Go to astrapy to perform the bulk insert
-            self._astra_db_collection.insert_many(batch)
+                def _replace_document(document: Dict[str, Any]) -> UpdateResult:
+                    return self._collection.replace_one(
+                        {"_id": document["_id"]},
+                        document,
+                    )
+
+                replace_results = executor.map(
+                    _replace_document,
+                    documents_to_replace,
+                )
+
+            replaced_count = sum(r_res.update_info["n"] for r_res in replace_results)
+            if replaced_count != len(ids_to_replace):
+                missing = len(ids_to_replace) - replaced_count
+                raise ValueError(
+                    "AstraDBVectorStore.add could not insert all requested "
+                    f"documents ({missing} failed replace_one calls)"
+                )
 
         # Return the list of ids
-        return [str(n["_id"]) for n in nodes_list]
+        return [str(n["_id"]) for n in documents_to_insert]
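
The fixed-size batching (MAX_INSERT_BATCH_SIZE) is gone: insertion is a single insert_many with ordered=False, falling back to per-document replace_one calls for the ids the insertion reports as already present. The pattern in isolation, as a sketch under the same astrapy 1.x assumption; upsert_documents is a hypothetical helper mirroring the add() logic, not part of the commit:

```python
from typing import Any, Dict, List
from concurrent.futures import ThreadPoolExecutor

from astrapy.exceptions import InsertManyException


def upsert_documents(
    collection: Any,
    documents: List[Dict[str, Any]],
    max_threads: int = 12,
) -> None:
    """Insert documents, overwriting any whose _id already exists."""
    try:
        collection.insert_many(documents, ordered=False)
        return
    except InsertManyException as err:
        # partial_result records which inserts did go through.
        inserted = set(err.partial_result.inserted_ids)

    to_replace = [doc for doc in documents if doc["_id"] not in inserted]
    with ThreadPoolExecutor(max_workers=max_threads) as executor:
        results = list(
            executor.map(
                lambda doc: collection.replace_one({"_id": doc["_id"]}, doc),
                to_replace,
            )
        )
    # Each UpdateResult's update_info["n"] counts affected documents.
    if sum(res.update_info["n"] for res in results) != len(to_replace):
        raise ValueError("some documents could not be overwritten")
```
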

     def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
         """
@@ -241,14 +290,27 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
             ref_doc_id (str): The id of the document to delete.
 
         """
-        _logger.debug("Deleting a document from the Astra table")
+        _logger.debug("Deleting a document from the Astra DB collection")
+
+        if delete_kwargs:
+            args_desc = ", ".join(
+                f"'{kwarg}'" for kwarg in sorted(delete_kwargs.keys())
+            )
+            warn(
+                (
+                    "AstraDBVectorStore.delete call got unsupported "
+                    f"named argument(s): {args_desc}."
+                ),
+                UserWarning,
+                stacklevel=2,
+            )
 
-        self._astra_db_collection.delete_one(id=ref_doc_id, **delete_kwargs)
+        self._collection.delete_one({"_id": ref_doc_id})
 
     @property
     def client(self) -> Any:
-        """Return the underlying Astra vector table object."""
-        return self._astra_db_collection
+        """Return the underlying Astra DB `astrapy.Collection` object."""
+        return self._collection
 
     @staticmethod
     def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]:
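
Deletion now uses a Data API filter document rather than the legacy id= keyword, and extra keyword arguments are warned about and ignored instead of being forwarded. A one-line usage sketch, with a collection object as in the earlier sketches:

```python
# Delete a single document by id -- the filter-document style
# replaces the old delete_one(id=...) keyword form.
collection.delete_one({"_id": "some-node-id"})
```
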
@@ -287,15 +349,19 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
         else:
             query_metadata = {}
 
+        matches: List[Dict[str, Any]]
+
         # Get the scores depending on the query mode
         if query.mode == VectorStoreQueryMode.DEFAULT:
-            # Call the vector_find method of AstraPy
-            matches = self._astra_db_collection.vector_find(
-                vector=query_embedding,
-                limit=query.similarity_top_k,
-                filter=query_metadata,
-                fields=["*"],
-                include_similarity=True,
-            )
+            matches = list(
+                self._collection.find(
+                    filter=query_metadata,
+                    projection={"*": True},
+                    limit=query.similarity_top_k,
+                    sort={"$vector": query_embedding},
+                    include_similarity=True,
+                )
+            )
 
             # Get the scores associated with each
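
The dedicated vector_find method is replaced by a plain find with a $vector sort. Standalone, the default query mode amounts to the sketch below; that each match carries its score in a "$similarity" field when include_similarity=True follows astrapy's Data API conventions, but treat the details as illustrative:

```python
# Placeholder query embedding; in the store this comes from the query object.
query_embedding = [0.1] * 1536

matches = list(
    collection.find(
        filter={},                        # optional metadata filter
        projection={"*": True},           # return every field
        limit=5,
        sort={"$vector": query_embedding},  # nearest-neighbour ordering
        include_similarity=True,
    )
)
# With include_similarity=True each document carries a "$similarity" score.
top_scores = [match["$similarity"] for match in matches]
```
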
@@ -322,11 +388,13 @@
             prefetch_k = max(prefetch_k0, query.similarity_top_k)
 
             # Call AstraPy to fetch them (similarity from DB not needed here)
-            prefetch_matches = self._astra_db_collection.vector_find(
-                vector=query_embedding,
-                limit=prefetch_k,
-                filter=query_metadata,
-                fields=["*"],
-            )
+            prefetch_matches = list(
+                self._collection.find(
+                    filter=query_metadata,
+                    projection={"*": True},
+                    limit=prefetch_k,
+                    sort={"$vector": query_embedding},
+                )
+            )
 
             # Get the MMR threshold
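
For MMR the store over-fetches prefetch_k candidates (no server-side similarity needed) and re-ranks them client-side. A sketch of that re-ranking step, assuming each candidate document exposes its embedding under "$vector" and that llama-index's helper accepts the keyword arguments shown; both are assumptions, not spelled out in this hunk:

```python
from llama_index.core.indices.query.embedding_utils import (
    get_top_k_mmr_embeddings,
)

# prefetch_matches / query_embedding as in the find() call above.
# Assumption: the {"*": True} projection includes each document's "$vector".
prefetch_embeddings = [match["$vector"] for match in prefetch_matches]
prefetch_ids = [match["_id"] for match in prefetch_matches]

# Re-rank for relevance vs. diversity; mmr_threshold interpolates the two.
mmr_similarities, mmr_ids = get_top_k_mmr_embeddings(
    query_embedding,
    prefetch_embeddings,
    similarity_top_k=5,
    embedding_ids=prefetch_ids,
    mmr_threshold=0.5,
)
```
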
