diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/llama_index/vector_stores/astra_db/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/llama_index/vector_stores/astra_db/base.py
index 4563fc2bc5d1e..080e26b11033f 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/llama_index/vector_stores/astra_db/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/llama_index/vector_stores/astra_db/base.py
@@ -1,19 +1,23 @@
 """
-Astra DB Vector store index.
+Astra DB Vector Store index.

-An index based on a DB table with vector search capabilities,
-powered by the astrapy library
+An index based on a DB collection with vector search capabilities,
+powered by the AstraPy library.
 """

 import json
 import logging
 from typing import Any, Dict, List, Optional, cast
+from concurrent.futures import ThreadPoolExecutor
 from warnings import warn

+from astrapy import DataAPIClient
+from astrapy.exceptions import InsertManyException
+from astrapy.results import UpdateResult
+
 import llama_index.core
 from llama_index.core.bridge.pydantic import PrivateAttr
-from astrapy.db import AstraDB
 from llama_index.core.indices.query.embedding_utils import get_top_k_mmr_embeddings
 from llama_index.core.schema import BaseNode, MetadataMode
 from llama_index.core.vector_stores.types import (
@@ -34,7 +38,8 @@
 _logger = logging.getLogger(__name__)

 DEFAULT_MMR_PREFETCH_FACTOR = 4.0
-MAX_INSERT_BATCH_SIZE = 20
+
+REPLACE_DOCUMENTS_MAX_THREADS = 12

 NON_INDEXED_FIELDS = ["metadata._node_content", "content"]

@@ -43,15 +48,14 @@ class AstraDBVectorStore(BasePydanticVectorStore):
     """
     Astra DB Vector Store.

-    An abstraction of a Astra table with
+    An abstraction of an Astra DB collection with
     vector-similarity-search. Documents, and their embeddings, are stored
-    in an Astra table and a vector-capable index is used for searches.
-    The table does not need to exist beforehand: if necessary it will
-    be created behind the scenes.
+    in an Astra DB collection equipped with a vector index.
+    The collection, if necessary, is created when the vector store is initialized.

-    All Astra operations are done through the astrapy library.
+    All Astra operations are done through the AstraPy library.

-    Visit https://astra.datastax.com/signup to create an account and get an API key.
+    Visit https://astra.datastax.com/signup to create an account and get started.

     Args:
         collection_name (str): collection name to use. If not existing, it will be created.
@@ -59,8 +63,6 @@ class AstraDBVectorStore(BasePydanticVectorStore):
         api_endpoint (str): The Astra DB JSON API endpoint for your database.
         embedding_dimension (int): length of the embedding vectors in use.
         namespace (Optional[str]): The namespace to use. If not provided, 'default_keyspace'
-        ttl_seconds (Optional[int]): expiration time for inserted entries.
-            Default is no expiration.

     Examples:
         `pip install llama-index-vector-stores-astra`

@@ -70,7 +72,7 @@ class AstraDBVectorStore(BasePydanticVectorStore):

         # Create the Astra DB Vector Store object
         astra_db_store = AstraDBVectorStore(
-            collection_name="astra_v_table",
+            collection_name="astra_v_store",
             token=token,
             api_endpoint=api_endpoint,
             embedding_dimension=1536,
@@ -83,9 +85,8 @@ class AstraDBVectorStore(BasePydanticVectorStore):
     flat_metadata: bool = True

     _embedding_dimension: int = PrivateAttr()
-    _ttl_seconds: Optional[int] = PrivateAttr()
-    _astra_db: Any = PrivateAttr()
-    _astra_db_collection: Any = PrivateAttr()
+    _database: Any = PrivateAttr()
+    _collection: Any = PrivateAttr()

     def __init__(
         self,
@@ -101,46 +102,55 @@ def __init__(
         # Set all the required class parameters
         self._embedding_dimension = embedding_dimension
-        self._ttl_seconds = ttl_seconds

-        _logger.debug("Creating the Astra DB table")
+        if ttl_seconds is not None:
+            warn(
+                (
+                    "Parameter `ttl_seconds` is not supported for "
+                    "`AstraDBVectorStore` and will be ignored."
+                ),
+                UserWarning,
+                stacklevel=2,
+            )

-        # Build the Astra DB object
-        self._astra_db = AstraDB(
-            api_endpoint=api_endpoint,
-            token=token,
-            namespace=namespace,
+        _logger.debug("Creating the Astra DB client and database instances")
+
+        # Build the Database object
+        self._database = DataAPIClient(
             caller_name=getattr(llama_index, "__name__", "llama_index"),
             caller_version=getattr(llama_index.core, "__version__", None),
+        ).get_database(
+            api_endpoint,
+            token=token,
+            namespace=namespace,
         )

-        from astrapy.api import APIRequestError
+        from astrapy.exceptions import DataAPIException
+
+        collection_indexing = {"deny": NON_INDEXED_FIELDS}

         try:
+            _logger.debug("Creating the Astra DB collection")
             # Create and connect to the newly created collection
-            self._astra_db_collection = self._astra_db.create_collection(
-                collection_name=collection_name,
+            self._collection = self._database.create_collection(
+                name=collection_name,
                 dimension=embedding_dimension,
-                options={"indexing": {"deny": NON_INDEXED_FIELDS}},
+                indexing=collection_indexing,
+                check_exists=False,
             )
-        except APIRequestError:
+        except DataAPIException:
             # possibly the collection is preexisting and has legacy
             # indexing settings: verify
-            get_coll_response = self._astra_db.get_collections(
-                options={"explain": True}
-            )
-            collections = (get_coll_response["status"] or {}).get("collections") or []
             preexisting = [
-                collection
-                for collection in collections
-                if collection["name"] == collection_name
+                coll_descriptor
+                for coll_descriptor in self._database.list_collections()
+                if coll_descriptor.name == collection_name
             ]

             if preexisting:
-                pre_collection = preexisting[0]
                 # if it has no "indexing", it is a legacy collection;
-                # otherwise it's unexpected warn and proceed at user's risk
-                pre_col_options = pre_collection.get("options") or {}
-                if "indexing" not in pre_col_options:
+                # otherwise it's unexpected: warn and proceed at user's risk
+                pre_col_idx_opts = preexisting[0].options.indexing or {}
+                if not pre_col_idx_opts:
                     warn(
                         (
                             f"Collection '{collection_name}' is detected as "
@@ -148,32 +158,36 @@ def __init__(
                             "(either created manually or by older versions "
                             "of this plugin). This implies stricter "
                             "limitations on the amount of text"
-                            " each entry can store. Consider reindexing anew on a"
+                            " each entry can store. Consider indexing anew on a"
                             " fresh collection to be able to store longer texts."
                         ),
                         UserWarning,
                         stacklevel=2,
                     )
-                    self._astra_db_collection = self._astra_db.collection(
-                        collection_name=collection_name,
+                    self._collection = self._database.get_collection(
+                        collection_name,
                     )
                 else:
-                    options_json = json.dumps(pre_col_options["indexing"])
-                    warn(
-                        (
-                            f"Collection '{collection_name}' has unexpected 'indexing'"
-                            f" settings (options.indexing = {options_json})."
-                            " This can result in odd behaviour when running "
-                            " metadata filtering and/or unwarranted limitations"
-                            " on storing long texts. Consider reindexing anew on a"
-                            " fresh collection."
-                        ),
-                        UserWarning,
-                        stacklevel=2,
-                    )
-                    self._astra_db_collection = self._astra_db.collection(
-                        collection_name=collection_name,
-                    )
+                    # check if the indexing options match entirely
+                    if pre_col_idx_opts == collection_indexing:
+                        raise
+                    else:
+                        options_json = json.dumps(pre_col_idx_opts)
+                        warn(
+                            (
+                                f"Collection '{collection_name}' has unexpected 'indexing'"
+                                f" settings (options.indexing = {options_json})."
+                                " This can result in odd behaviour when running"
+                                " metadata filtering and/or unwarranted limitations"
+                                " on storing long texts. Consider indexing anew on a"
+                                " fresh collection."
+                            ),
+                            UserWarning,
+                            stacklevel=2,
+                        )
+                        self._collection = self._database.get_collection(
+                            collection_name,
+                        )
             else:
                 # other exception
                 raise
@@ -190,8 +204,8 @@ def add(
             nodes: List[BaseNode]: list of node with embeddings

         """
-        # Initialize list of objects to track
-        nodes_list = []
+        # Initialize list of documents to insert
+        documents_to_insert: List[Dict[str, Any]] = []

         # Process each node individually
         for node in nodes:
@@ -203,7 +217,7 @@ def add(
             )

             # One dictionary of node data per node
-            nodes_list.append(
+            documents_to_insert.append(
                 {
                     "_id": node.node_id,
                     "content": node.get_content(metadata_mode=MetadataMode.NONE),
@@ -212,26 +226,61 @@ def add(
                 }
             )

-        # Log the number of rows being added
-        _logger.debug(f"Adding {len(nodes_list)} rows to table")
+        # Log the number of documents being added
+        _logger.debug(f"Adding {len(documents_to_insert)} documents to the collection")

-        # Initialize an empty list to hold the batches
-        batched_list = []
+        # perform an AstraPy insert_many, catching exceptions for overwriting docs
+        ids_to_replace: List[str]
+        try:
+            self._collection.insert_many(
+                documents_to_insert,
+                ordered=False,
+            )
+            ids_to_replace = []
+        except InsertManyException as err:
+            inserted_ids_set = set(err.partial_result.inserted_ids)
+            ids_to_replace = [
+                document["_id"]
+                for document in documents_to_insert
+                if document["_id"] not in inserted_ids_set
+            ]
+            _logger.debug(
+                f"Detected {len(ids_to_replace)} non-inserted documents, trying replace_one"
+            )

-        # Iterate over the node_list in steps of MAX_INSERT_BATCH_SIZE
-        for i in range(0, len(nodes_list), MAX_INSERT_BATCH_SIZE):
-            # Append a slice of node_list to the batched_list
-            batched_list.append(nodes_list[i : i + MAX_INSERT_BATCH_SIZE])
+        # if necessary, replace docs for the non-inserted ids
+        if ids_to_replace:
+            documents_to_replace = [
+                document
+                for document in documents_to_insert
+                if document["_id"] in ids_to_replace
+            ]

-        # Perform the bulk insert
-        for i, batch in enumerate(batched_list):
-            _logger.debug(f"Processing batch #{i + 1} of size {len(batch)}")
+            with ThreadPoolExecutor(
+                max_workers=REPLACE_DOCUMENTS_MAX_THREADS
+            ) as executor:

-            # Go to astrapy to perform the bulk insert
-            self._astra_db_collection.insert_many(batch)
+                def _replace_document(document: Dict[str, Any]) -> UpdateResult:
+                    return self._collection.replace_one(
document["_id"]}, + document, + ) + + replace_results = executor.map( + _replace_document, + documents_to_replace, + ) + + replaced_count = sum(r_res.update_info["n"] for r_res in replace_results) + if replaced_count != len(ids_to_replace): + missing = len(ids_to_replace) - replaced_count + raise ValueError( + "AstraDBVectorStore.add could not insert all requested " + f"documents ({missing} failed replace_one calls)" + ) # Return the list of ids - return [str(n["_id"]) for n in nodes_list] + return [str(n["_id"]) for n in documents_to_insert] def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: """ @@ -241,14 +290,27 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: ref_doc_id (str): The id of the document to delete. """ - _logger.debug("Deleting a document from the Astra table") + _logger.debug("Deleting a document from the Astra DB collection") + + if delete_kwargs: + args_desc = ", ".join( + f"'{kwarg}'" for kwarg in sorted(delete_kwargs.keys()) + ) + warn( + ( + "AstraDBVectorStore.delete call got unsupported " + f"named argument(s): {args_desc}." + ), + UserWarning, + stacklevel=2, + ) - self._astra_db_collection.delete_one(id=ref_doc_id, **delete_kwargs) + self._collection.delete_one({"_id": ref_doc_id}) @property def client(self) -> Any: - """Return the underlying Astra vector table object.""" - return self._astra_db_collection + """Return the underlying Astra DB `astrapy.Collection` object.""" + return self._collection @staticmethod def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]: @@ -287,15 +349,19 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul else: query_metadata = {} + matches: List[Dict[str, Any]] + # Get the scores depending on the query mode if query.mode == VectorStoreQueryMode.DEFAULT: # Call the vector_find method of AstraPy - matches = self._astra_db_collection.vector_find( - vector=query_embedding, - limit=query.similarity_top_k, - filter=query_metadata, - fields=["*"], - include_similarity=True, + matches = list( + self._collection.find( + filter=query_metadata, + projection={"*": True}, + limit=query.similarity_top_k, + sort={"$vector": query_embedding}, + include_similarity=True, + ) ) # Get the scores associated with each @@ -322,11 +388,13 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul prefetch_k = max(prefetch_k0, query.similarity_top_k) # Call AstraPy to fetch them (similarity from DB not needed here) - prefetch_matches = self._astra_db_collection.vector_find( - vector=query_embedding, - limit=prefetch_k, - filter=query_metadata, - fields=["*"], + prefetch_matches = list( + self._collection.find( + filter=query_metadata, + projection={"*": True}, + limit=prefetch_k, + sort={"$vector": query_embedding}, + ) ) # Get the MMR threshold diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/pyproject.toml index 915b84d638622..5bbdbe4517b48 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/pyproject.toml @@ -27,12 +27,12 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-vector-stores-astra-db" readme = "README.md" -version = "0.1.7" +version = "0.1.8" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" -astrapy = "^1" +astrapy = "^1.3" 

 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
@@ -40,7 +40,7 @@ jupyter = "^1.0.0"
 mypy = "0.991"
 pre-commit = "3.2.0"
 pylint = "2.15.10"
-pytest = "7.2.1"
+pytest = "~8.0.0"
 pytest-mock = "3.11.1"
 ruff = "0.0.292"
 tree-sitter-languages = "^1.8.0"
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/tests/test_astra_db.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/tests/test_astra_db.py
index ddb6338f126cf..a886c6e5bb54b 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/tests/test_astra_db.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-astra-db/tests/test_astra_db.py
@@ -2,17 +2,15 @@
 import pytest
 from typing import Iterable

-import astrapy
 from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode
 from llama_index.core.vector_stores.types import VectorStoreQuery

 from llama_index.vector_stores.astra_db import AstraDBVectorStore

-print(f"astrapy detected: {astrapy.__version__}")
-
 # env variables
 ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "")
 ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT", "")
+ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE")


 @pytest.fixture(scope="module")
@@ -21,11 +19,13 @@ def astra_db_store() -> Iterable[AstraDBVectorStore]:
         token=ASTRA_DB_APPLICATION_TOKEN,
         api_endpoint=ASTRA_DB_API_ENDPOINT,
         collection_name="test_collection",
+        namespace=ASTRA_DB_KEYSPACE,
         embedding_dimension=2,
     )
+    store._collection.delete_many({})

     yield store

-    store._astra_db.delete_collection("test_collection")
+    store._collection.drop()


 @pytest.mark.skipif(
@@ -33,6 +33,7 @@ def astra_db_store() -> Iterable[AstraDBVectorStore]:
     reason="missing Astra DB credentials",
 )
 def test_astra_db_create_and_crud(astra_db_store: AstraDBVectorStore) -> None:
+    """Test basic creation and insertion/deletion of a node."""
     astra_db_store.add(
         [
             TextNode(
@@ -54,8 +55,46 @@ def test_astra_db_create_and_crud(astra_db_store: AstraDBVectorStore) -> None:
     reason="missing Astra DB credentials",
 )
 def test_astra_db_queries(astra_db_store: AstraDBVectorStore) -> None:
+    """Test basic querying."""
     query = VectorStoreQuery(query_embedding=[1, 1], similarity_top_k=3)

     astra_db_store.query(
         query,
     )
+
+
+@pytest.mark.skipif(
+    ASTRA_DB_APPLICATION_TOKEN == "" or ASTRA_DB_API_ENDPOINT == "",
+    reason="missing Astra DB credentials",
+)
+def test_astra_db_insertions(astra_db_store: AstraDBVectorStore) -> None:
+    """Test massive insertion with overwrites."""
+    all_ids = list(range(150))
+    nodes0 = [
+        TextNode(
+            text=f"OLD_node {idx}",
+            id_=f"n_{idx}",
+            relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="doc_0")},
+            embedding=[0.5, 0.5],
+        )
+        for idx in all_ids[60:80] + all_ids[130:140]
+    ]
+    nodes = [
+        TextNode(
+            text=f"node {idx}",
+            id_=f"n_{idx}",
+            relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="doc_0")},
+            embedding=[0.5, 0.5],
+        )
+        for idx in all_ids
+    ]
+
+    astra_db_store.add(nodes0)
+    found_contents0 = [doc["content"] for doc in astra_db_store._collection.find({})]
+    assert all(f_content[:4] == "OLD_" for f_content in found_contents0)
+    assert len(found_contents0) == len(nodes0)
+
+    astra_db_store.add(nodes)
+    found_contents = [doc["content"] for doc in astra_db_store._collection.find({})]
+    assert all(f_content[:5] == "node " for f_content in found_contents)
+    assert len(found_contents) == len(nodes)
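
For reference, a minimal usage sketch of the store after this change. The environment-variable names mirror the ones used in the test suite above; the node ids, texts, and 2-dimensional embeddings are illustrative only. Re-adding a node with an existing id exercises the new overwrite path: the unordered `insert_many` raises `InsertManyException`, and `add()` falls back to `replace_one` calls.

```python
import os

from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.vector_stores.astra_db import AstraDBVectorStore

# Credentials via env vars, as in the test suite (names are illustrative).
store = AstraDBVectorStore(
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    collection_name="astra_v_store",
    embedding_dimension=2,
)

# Second add() with the same id triggers the insert-then-replace fallback.
store.add([TextNode(id_="n_0", text="first version", embedding=[0.1, 0.9])])
store.add([TextNode(id_="n_0", text="second version", embedding=[0.2, 0.8])])

# ANN search against the collection's vector index.
result = store.query(VectorStoreQuery(query_embedding=[0.2, 0.8], similarity_top_k=1))
print(result.ids)  # ["n_0"]
```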
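The overwrite fallback in `add()` can also be sketched standalone against AstraPy 1.x. This is only a sketch: the endpoint and token are placeholders, and the thread pool is omitted for brevity; the client calls themselves (`get_database`, `insert_many(..., ordered=False)`, `InsertManyException.partial_result.inserted_ids`, `replace_one`) are the same ones the diff introduces.

```python
from astrapy import DataAPIClient
from astrapy.exceptions import InsertManyException

# Placeholder endpoint and token: substitute real values.
database = DataAPIClient().get_database(
    "https://<database-id>-<region>.apps.astra.datastax.com",
    token="AstraCS:...",
)
collection = database.get_collection("astra_v_store")

documents = [{"_id": f"doc_{i}", "content": f"text {i}"} for i in range(3)]

try:
    # Unordered insertion: every non-conflicting document lands in one call.
    collection.insert_many(documents, ordered=False)
except InsertManyException as err:
    # Conflicting _ids are missing from partial_result; replace those individually.
    inserted = set(err.partial_result.inserted_ids)
    for document in documents:
        if document["_id"] not in inserted:
            collection.replace_one({"_id": document["_id"]}, document)
```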