From 960507353f516fc83aedf3f97e5fbc49038c9092 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Thu, 10 Oct 2024 12:15:39 +0200 Subject: [PATCH] more tweaks --- .../langchain_astradb/graph_vectorstores.py | 179 ++++++++++++------ .../utils/{mmr_traversal.py => mmr_helper.py} | 0 .../astradb/langchain_astradb/vectorstores.py | 16 +- .../test_graphvectorstore.py | 1 - .../tests/unit_tests/test_mmr_helper.py | 2 +- 5 files changed, 128 insertions(+), 70 deletions(-) rename libs/astradb/langchain_astradb/utils/{mmr_traversal.py => mmr_helper.py} (100%) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index 85feec8..76dd615 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -16,17 +16,13 @@ cast, ) -from langchain_community.graph_vectorstores.base import ( - METADATA_LINKS_KEY, - GraphVectorStore, - Link, - Node, -) +from langchain_community.graph_vectorstores.base import GraphVectorStore, Node +from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link from langchain_core._api import beta from langchain_core.documents import Document from typing_extensions import override -from langchain_astradb.utils.mmr_traversal import MmrHelper +from langchain_astradb.utils.mmr_helper import MmrHelper from langchain_astradb.vectorstores import AstraDBVectorStore if TYPE_CHECKING: @@ -97,7 +93,6 @@ def __init__( embedding: Embeddings, collection_name: str, metadata_incoming_links_key: str = "incoming_links", - metadata_links_key: str = METADATA_LINKS_KEY, token: str | TokenProvider | None = None, api_endpoint: str | None = None, environment: str | None = None, @@ -125,8 +120,6 @@ def __init__( collection_name: name of the Astra DB collection to create/use. metadata_incoming_links_key: document metadata key where the incoming links are stored (and indexed). 
- metadata_links_key: document metadata key where all the links are - stored (and not indexed). token: API token for Astra DB usage, either in the form of a string or a subclass of ``astrapy.authentication.TokenProvider``. If not provided, the environment variable @@ -206,7 +199,6 @@ def __init__( ``metric``, ``setup_mode``, ``metadata_indexing_include``, ``metadata_indexing_exclude``, ``collection_indexing_policy``. """ - self.metadata_links_key = metadata_links_key self.metadata_incoming_links_key = metadata_incoming_links_key self.embedding = embedding @@ -214,15 +206,15 @@ def __init__( # full links blob is not. if collection_indexing_policy is not None: collection_indexing_policy["allow"].append(self.metadata_incoming_links_key) - collection_indexing_policy["deny"].append(self.metadata_links_key) + collection_indexing_policy["deny"].append(METADATA_LINKS_KEY) elif metadata_indexing_include is not None: metadata_indexing_include = set(metadata_indexing_include) metadata_indexing_include.add(self.metadata_incoming_links_key) elif metadata_indexing_exclude is not None: metadata_indexing_exclude = set(metadata_indexing_exclude) - metadata_indexing_exclude.add(self.metadata_links_key) + metadata_indexing_exclude.add(METADATA_LINKS_KEY) elif not autodetect_collection: - metadata_indexing_exclude = [self.metadata_links_key] + metadata_indexing_exclude = [METADATA_LINKS_KEY] self.vector_store = AstraDBVectorStore( collection_name=collection_name, @@ -278,15 +270,15 @@ def _restore_links(self, doc: Document) -> Document: Returns: The same Document with restored links. 
""" - links = _deserialize_links(doc.metadata.get(self.metadata_links_key)) - doc.metadata[self.metadata_links_key] = links + links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY)) + doc.metadata[METADATA_LINKS_KEY] = links del doc.metadata[self.metadata_incoming_links_key] return doc def _doc_to_node(self, doc: Document) -> Node: metadata = doc.metadata.copy() - links = _deserialize_links(metadata.get(self.metadata_links_key)) - metadata[self.metadata_links_key] = links + links = _deserialize_links(metadata.get(METADATA_LINKS_KEY)) + metadata[METADATA_LINKS_KEY] = links return Node( id=doc.id, @@ -302,13 +294,19 @@ def add_nodes( nodes: Iterable[Node], **kwargs: Any, ) -> Iterable[str]: + """Add nodes to the graph store. + + Args: + nodes: the nodes to add. + **kwargs: Additional keyword arguments. + """ docs = [] ids = [] for node in nodes: node_id = secrets.token_hex(8) if not node.id else node.id combined_metadata = node.metadata.copy() - combined_metadata[self.metadata_links_key] = _serialize_links(node.links) + combined_metadata[METADATA_LINKS_KEY] = _serialize_links(node.links) combined_metadata[self.metadata_incoming_links_key] = [ _metadata_link_key(link=link) for link in _incoming_links(node=node) ] @@ -357,15 +355,26 @@ def similarity_search( self, query: str, k: int = 4, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> list[Document]: + """Retrieve documents from this graph store. + + Args: + query: The query string. + k: The number of Documents to return. Defaults to 4. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. 
+ """ return [ self._restore_links(doc) for doc in self.vector_store.similarity_search( query=query, k=k, - filter=metadata_filter, + filter=filter, **kwargs, ) ] @@ -375,15 +384,26 @@ async def asimilarity_search( self, query: str, k: int = 4, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> list[Document]: + """Retrieve documents from this graph store. + + Args: + query: The query string. + k: The number of Documents to return. Defaults to 4. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. + """ return [ self._restore_links(doc) for doc in await self.vector_store.asimilarity_search( query=query, k=k, - filter=metadata_filter, + filter=filter, **kwargs, ) ] @@ -393,15 +413,26 @@ def similarity_search_by_vector( self, embedding: list[float], k: int = 4, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> list[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + **kwargs: Additional arguments are ignored. + + Returns: + The list of Documents most similar to the query vector. + """ return [ self._restore_links(doc) for doc in self.vector_store.similarity_search_by_vector( embedding, k=k, - filter=metadata_filter, + filter=filter, **kwargs, ) ] @@ -411,49 +442,77 @@ async def asimilarity_search_by_vector( self, embedding: list[float], k: int = 4, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> list[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. 
+ **kwargs: Additional arguments are ignored. + + Returns: + The list of Documents most similar to the query vector. + """ return [ self._restore_links(doc) for doc in await self.vector_store.asimilarity_search_by_vector( embedding, k=k, - filter=metadata_filter, + filter=filter, **kwargs, ) ] def metadata_search( self, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, # noqa: A002 n: int = 5, ) -> Iterable[Document]: - """Retrieve nodes based on their metadata.""" + """Get documents via a metadata search. + + Args: + filter: the metadata to query for. + n: the maximum number of documents to return. + """ return [ self._restore_links(doc) for doc in self.vector_store.metadata_search( - metadata_filter=metadata_filter, + filter=filter, n=n, ) ] async def ametadata_search( self, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, # noqa: A002 n: int = 5, ) -> Iterable[Document]: - """Retrieve nodes based on their metadata.""" + """Get documents via a metadata search. + + Args: + filter: the metadata to query for. + n: the maximum number of documents to return. + """ return [ self._restore_links(doc) for doc in await self.vector_store.ametadata_search( - metadata_filter=metadata_filter, + filter=filter, n=n, ) ] def get_node(self, node_id: str) -> Node | None: - """Get a node by its id.""" + """Retrieve a single node from the store, given its ID. + + Args: + node_id: The node ID + + Returns: + The node if it exists. Otherwise None. + """ doc = self.vector_store.get_by_document_id(document_id=node_id) if doc is None: return None @@ -471,7 +530,7 @@ async def ammr_traversal_search( # noqa: C901 adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[Document]: """Retrieve documents from this graph store using MMR-traversal. 
@@ -503,8 +562,8 @@ async def ammr_traversal_search( # noqa: C901 diversity and 1 to minimum diversity. Defaults to 0.5. score_threshold: Only documents with a score greater than or equal this threshold will be chosen. Defaults to -infinity. - metadata_filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments for future use. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. """ query_embedding = self.embedding.embed_query(query) helper = MmrHelper( @@ -538,7 +597,7 @@ async def fetch_neighborhood(neighborhood: Sequence[str]) -> None: links=visited_links, query_embedding=query_embedding, k_per_link=adjacent_k, - metadata_filter=metadata_filter, + filter=filter, retrieved_docs=retrieved_docs, ) @@ -558,7 +617,7 @@ async def fetch_initial_candidates() -> None: await self.vector_store.asimilarity_search_with_embedding_id_by_vector( embedding=query_embedding, k=fetch_k, - filter=metadata_filter, + filter=filter, ) ) @@ -604,7 +663,7 @@ async def fetch_initial_candidates() -> None: links=selected_outgoing_links, query_embedding=query_embedding, k_per_link=adjacent_k, - metadata_filter=metadata_filter, + filter=filter, retrieved_docs=retrieved_docs, ) @@ -658,7 +717,7 @@ def mmr_traversal_search( adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), - metadata_filter: dict[str, Any] = {}, # noqa: B006 + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> Iterable[Document]: """Retrieve documents from this graph store using MMR-traversal. @@ -690,8 +749,8 @@ def mmr_traversal_search( diversity and 1 to minimum diversity. Defaults to 0.5. score_threshold: Only documents with a score greater than or equal this threshold will be chosen. Defaults to -infinity. - metadata_filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments for future use. + filter: Optional metadata to filter the results. 
+ **kwargs: Additional keyword arguments. """ async def collect_docs() -> Iterable[Document]: @@ -704,7 +763,7 @@ async def collect_docs() -> Iterable[Document]: adjacent_k=adjacent_k, lambda_mult=lambda_mult, score_threshold=score_threshold, - metadata_filter=metadata_filter, + filter=filter, **kwargs, ) return [doc async for doc in async_iter] @@ -718,7 +777,7 @@ async def atraversal_search( # noqa: C901 *, k: int = 4, depth: int = 1, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[Document]: """Retrieve documents from this knowledge store. @@ -732,8 +791,8 @@ async def atraversal_search( # noqa: C901 k: The number of Documents to return from the initial vector search. Defaults to 4. depth: The maximum depth of edges to traverse. Defaults to 1. - metadata_filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments for future use. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. Returns: Collection of retrieved documents. 
@@ -791,14 +850,14 @@ async def visit_nodes(d: int, docs: Iterable[Document]) -> None: if outgoing_links: metadata_search_tasks = [] for outgoing_link in outgoing_links: - _metadata_filter = self._get_metadata_filter( - metadata=metadata_filter, + metadata_filter = self._get_metadata_filter( + metadata=filter, outgoing_link=outgoing_link, ) metadata_search_tasks.append( asyncio.create_task( self.vector_store.ametadata_search( - metadata_filter=_metadata_filter, n=1000 + filter=metadata_filter, n=1000 ) ) ) @@ -852,7 +911,7 @@ async def visit_targets(d: int, docs: Iterable[Document]) -> None: initial_docs = self.vector_store.similarity_search( query=query, k=k, - filter=metadata_filter, + filter=filter, ) await visit_nodes(d=0, docs=initial_docs) @@ -870,7 +929,7 @@ def traversal_search( *, k: int = 4, depth: int = 1, - metadata_filter: dict[str, Any] = {}, # noqa: B006 + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> Iterable[Document]: """Retrieve documents from this knowledge store. @@ -884,8 +943,8 @@ def traversal_search( k: The number of Documents to return from the initial vector search. Defaults to 4. depth: The maximum depth of edges to traverse. Defaults to 1. - metadata_filter: Optional metadata to filter the results. - **kwargs: Additional keyword arguments for future use. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. Returns: Collection of retrieved documents. 
@@ -896,7 +955,7 @@ async def collect_docs() -> Iterable[Document]: query=query, k=k, depth=depth, - metadata_filter=metadata_filter, + filter=filter, **kwargs, ) return [doc async for doc in async_iter] @@ -937,7 +996,7 @@ async def _get_adjacent( query_embedding: list[float], retrieved_docs: dict[str, Document], k_per_link: int | None = None, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, # noqa: A002 ) -> Iterable[AdjacentNode]: """Return the target nodes with incoming links from any of the given links. @@ -946,7 +1005,7 @@ async def _get_adjacent( query_embedding: The query embedding. Used to rank target nodes. retrieved_docs: A cache of retrieved docs. This will be added to. k_per_link: The number of target nodes to fetch for each link. - metadata_filter: Optional metadata to filter the results. + filter: Optional metadata to filter the results. Returns: Iterable of adjacent edges. @@ -955,8 +1014,8 @@ async def _get_adjacent( tasks = [] for link in links: - _metadata_filter = self._get_metadata_filter( - metadata=metadata_filter, + metadata_filter = self._get_metadata_filter( + metadata=filter, outgoing_link=link, ) @@ -964,7 +1023,7 @@ async def _get_adjacent( self.vector_store.asimilarity_search_with_embedding_id_by_vector( embedding=query_embedding, k=k_per_link or 10, - filter=_metadata_filter, + filter=metadata_filter, ) ) diff --git a/libs/astradb/langchain_astradb/utils/mmr_traversal.py b/libs/astradb/langchain_astradb/utils/mmr_helper.py similarity index 100% rename from libs/astradb/langchain_astradb/utils/mmr_traversal.py rename to libs/astradb/langchain_astradb/utils/mmr_helper.py diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index 9f13275..b1cfcc4 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -1324,17 +1324,17 @@ async def _update_document( def metadata_search( self, - 
metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, # noqa: A002 n: int = 5, ) -> list[Document]: """Get documents via a metadata search. Args: - metadata_filter: the metadata to query for. + filter: the metadata to query for. n: the maximum number of documents to return. """ self.astra_env.ensure_db_setup() - metadata_parameter = self.filter_to_query(metadata_filter) + metadata_parameter = self.filter_to_query(filter) hits_ite = self.astra_env.collection.find( filter=metadata_parameter, projection=self.document_codec.base_projection, @@ -1345,17 +1345,17 @@ def metadata_search( async def ametadata_search( self, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, # noqa: A002 n: int = 5, ) -> Iterable[Document]: """Get documents via a metadata search. Args: - metadata_filter: the metadata to query for. + filter: the metadata to query for. n: the maximum number of documents to return. """ await self.astra_env.aensure_db_setup() - metadata_parameter = self.filter_to_query(metadata_filter) + metadata_parameter = self.filter_to_query(filter) return [ doc async for doc in ( @@ -1376,7 +1376,7 @@ def get_by_document_id(self, document_id: str) -> Document | None: document_id: The document ID Returns: - True if a document has indeed been deleted, False if ID not found. + The document if it exists. Otherwise None. """ self.astra_env.ensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) @@ -1395,7 +1395,7 @@ async def aget_by_document_id(self, document_id: str) -> Document | None: document_id: The document ID Returns: - True if a document has indeed been deleted, False if ID not found. + The document if it exists. Otherwise None. 
""" await self.astra_env.aensure_db_setup() # self.collection is not None (by _ensure_astra_db_client) diff --git a/libs/astradb/tests/integration_tests/test_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_graphvectorstore.py index 8ad48bf..60bd4e1 100644 --- a/libs/astradb/tests/integration_tests/test_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_graphvectorstore.py @@ -164,7 +164,6 @@ def autodetect_populated_graph_vector_store_d2( embedding=embedding_d2, collection_name=ephemeral_collection_cleaner_idxall_d2, metadata_incoming_links_key="x_link_to_x", - metadata_links_key="_links", token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], diff --git a/libs/astradb/tests/unit_tests/test_mmr_helper.py b/libs/astradb/tests/unit_tests/test_mmr_helper.py index 24a858b..e9bec6d 100644 --- a/libs/astradb/tests/unit_tests/test_mmr_helper.py +++ b/libs/astradb/tests/unit_tests/test_mmr_helper.py @@ -2,7 +2,7 @@ import math -from langchain_astradb.utils.mmr_traversal import MmrHelper +from langchain_astradb.utils.mmr_helper import MmrHelper IDS = { "-1",