diff --git a/libs/community/langchain_community/chains/graph_qa/arangodb.py b/libs/community/langchain_community/chains/graph_qa/arangodb.py
index 933cf91737cb5..41d4f6b57b097 100644
--- a/libs/community/langchain_community/chains/graph_qa/arangodb.py
+++ b/libs/community/langchain_community/chains/graph_qa/arangodb.py
@@ -57,6 +57,10 @@ class ArangoGraphQAChain(Chain):
     # Specify the maximum amount of AQL Generation attempts that should be made
     max_aql_generation_attempts: int = 3
 
+    # Specify whether to execute the generated AQL Query
+    # If False, the AQL Query is only explained & returned, not executed
+    execute_aql_query: bool = True
+
     allow_dangerous_requests: bool = False
     """Forced user opt-in to acknowledge that the chain can make dangerous requests.
 
@@ -155,6 +159,11 @@ def _call(
             AQL Query Execution Error. Defaults to 3.
         :type max_aql_generation_attempts: int
         """
+        try:
+            from arango import AQLQueryExecuteError, AQLQueryExplainError
+        except ImportError:
+            raise ImportError("ArangoDB not installed, please install with `pip install python-arango`.")
+
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         user_input = inputs[self.input_key]
@@ -176,48 +185,34 @@ def _call(
         aql_result = None
         aql_generation_attempt = 1
 
-        while (
-            aql_result is None
-            and aql_generation_attempt < self.max_aql_generation_attempts + 1
-        ):
+        aql_execution_func = self.graph.query if self.execute_aql_query else self.graph.explain
+
+        while aql_result is None and aql_generation_attempt < self.max_aql_generation_attempts + 1:
             #####################
             # Extract AQL Query #
             pattern = r"```(?i:aql)?(.*?)```"
             matches = re.findall(pattern, aql_generation_output, re.DOTALL)
             if not matches:
-                _run_manager.on_text(
-                    "Invalid Response: ", end="\n", verbose=self.verbose
-                )
-                _run_manager.on_text(
-                    aql_generation_output, color="red", end="\n", verbose=self.verbose
-                )
+                _run_manager.on_text("Invalid Response: ", end="\n", verbose=self.verbose)
+                _run_manager.on_text(aql_generation_output, color="red", end="\n", verbose=self.verbose)
                 raise ValueError(f"Response is Invalid: {aql_generation_output}")
 
             aql_query = matches[0]
             #####################
 
-            _run_manager.on_text(
-                f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
-            )
-            _run_manager.on_text(
-                aql_query, color="green", end="\n", verbose=self.verbose
-            )
+            _run_manager.on_text(f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose)
+            _run_manager.on_text(aql_query, color="green", end="\n", verbose=self.verbose)
 
-            #####################
-            # Execute AQL Query #
-            from arango import AQLQueryExecuteError
+            #############################
+            # Execute/Explain AQL Query #
 
             try:
-                aql_result = self.graph.query(aql_query, self.top_k)
-            except AQLQueryExecuteError as e:
+                aql_result = aql_execution_func(aql_query, self.top_k)
+            except (AQLQueryExecuteError, AQLQueryExplainError) as e:
                 aql_error = e.error_message
 
-                _run_manager.on_text(
-                    "AQL Query Execution Error: ", end="\n", verbose=self.verbose
-                )
-                _run_manager.on_text(
-                    aql_error, color="yellow", end="\n\n", verbose=self.verbose
-                )
+                _run_manager.on_text("AQL Query Execution Error: ", end="\n", verbose=self.verbose)
+                _run_manager.on_text(aql_error, color="yellow", end="\n\n", verbose=self.verbose)
 
                 ########################
                 # Retry AQL Generation #
@@ -243,10 +238,14 @@ def _call(
             """
             raise ValueError(m)
 
-        _run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose)
-        _run_manager.on_text(
-            str(aql_result), color="green", end="\n", verbose=self.verbose
-        )
+        text = "AQL Result:" if self.execute_aql_query else "AQL Explain:"
+        _run_manager.on_text(text, end="\n", verbose=self.verbose)
+        _run_manager.on_text(str(aql_result), color="green", end="\n", verbose=self.verbose)
+
+        if not self.execute_aql_query:
+            result = {self.output_key: aql_query}
+
+            return result
 
         ########################
         # Interpret AQL Result #
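
Usage sketch for the new flag (not part of the diff; the LLM choice and connection details are illustrative):

    from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
    from langchain_community.graphs.arangodb_graph import ArangoGraph, get_arangodb_client
    from langchain_openai import ChatOpenAI  # any chat model works here

    graph = ArangoGraph(get_arangodb_client())

    # With execute_aql_query=False, the chain generates and explains the AQL
    # query but never runs it; the query string comes back under output_key.
    chain = ArangoGraphQAChain.from_llm(
        ChatOpenAI(temperature=0),
        graph=graph,
        allow_dangerous_requests=True,  # required opt-in
        execute_aql_query=False,
    )
    print(chain.invoke({"query": "Which movies did Tom Hanks act in?"}))
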
diff --git a/libs/community/langchain_community/graphs/arangodb_graph.py b/libs/community/langchain_community/graphs/arangodb_graph.py
index dd2ad16614f86..20adea4e1e834 100644
--- a/libs/community/langchain_community/graphs/arangodb_graph.py
+++ b/libs/community/langchain_community/graphs/arangodb_graph.py
@@ -1,11 +1,38 @@
+import itertools
+import json
 import os
+from collections import defaultdict
 from math import ceil
 from typing import Any, Dict, List, Optional
+from uuid import uuid4
 
+from langchain_community.graphs.graph_document import Document, GraphDocument, Node
+from langchain_community.graphs.graph_store import GraphStore
+from langchain_community.graphs.neo4j_graph import value_sanitize
 
-class ArangoGraph:
+try:
+    from arango.database import Database
+    from arango.graph import Graph
+
+    ARANGO_INSTALLED = True
+except ImportError:
+    print("ArangoDB not installed, please install with `pip install python-arango`.")
+    ARANGO_INSTALLED = False
+
+
+class ArangoGraph(GraphStore):
     """ArangoDB wrapper for graph operations.
 
+    Parameters:
+        db (arango.database.Database): ArangoDB database instance.
+        include_examples (bool): A flag whether to scan the database for
+            example values and use them in the graph schema. Default is True.
+        graph_name (str): The name of the graph to use to generate the schema. If
+            None, the entire database will be used.
+
     *Security note*: Make sure that the database connection uses credentials
         that are narrowly-scoped to only include necessary permissions.
         Failure to do so may result in data corruption or loss, since the calling
@@ -18,61 +45,102 @@ class ArangoGraph:
     See https://python.langchain.com/docs/security for more information.
     """
 
-    def __init__(self, db: Any) -> None:
-        """Create a new ArangoDB graph wrapper instance."""
-        self.set_db(db)
-        self.set_schema()
+    def __init__(
+        self,
+        db: Database,
+        include_examples: bool = True,
+        graph_name: Optional[str] = None,
+    ) -> None:
+        if not ARANGO_INSTALLED:
+            m = "ArangoDB not installed, please install with `pip install python-arango`."
+            raise ImportError(m)
+
+        self.__db: Database = db
+        self.__schema = self.generate_schema(include_examples=include_examples, graph_name=graph_name)
 
     @property
-    def db(self) -> Any:
+    def db(self) -> "Database":
         return self.__db
 
     @property
     def schema(self) -> Dict[str, Any]:
+        """Returns the schema of the Graph Database as a structured object"""
         return self.__schema
 
-    def set_db(self, db: Any) -> None:
-        from arango.database import Database
-
-        if not isinstance(db, Database):
-            msg = "**db** parameter must inherit from arango.database.Database"
-            raise TypeError(msg)
-
-        self.__db: Database = db
-        self.set_schema()
+    @property
+    def get_structured_schema(self) -> Dict[str, Any]:
+        """Returns the schema of the Graph Database as a structured object"""
+        return self.__schema
 
-    def set_schema(self, schema: Optional[Dict[str, Any]] = None) -> None:
-        """
-        Set the schema of the ArangoDB Database.
-        Auto-generates Schema if **schema** is None.
-        """
-        self.__schema = self.generate_schema() if schema is None else schema
+    @property
+    def get_schema(self) -> str:
+        """Returns the schema of the Graph Database as a string"""
+        return json.dumps(self.__schema)
+
+    def set_schema(self, schema: Dict[str, Any]) -> None:
+        """Sets a custom schema for the ArangoDB Database."""
+        self.__schema = schema
+
+    def refresh_schema(
+        self,
+        sample_ratio: float = 0,
+        graph_name: Optional[str] = None,
+        include_examples: bool = True,
+    ) -> None:
+        """Refresh the graph schema information."""
+        self.__schema = self.generate_schema(sample_ratio, graph_name, include_examples)
 
     def generate_schema(
-        self, sample_ratio: float = 0
+        self,
+        sample_ratio: float = 0,
+        graph_name: Optional[str] = None,
+        include_examples: bool = True,
     ) -> Dict[str, List[Dict[str, Any]]]:
         """
         Generates the schema of the ArangoDB Database and returns it
-        User can specify a **sample_ratio** (0 to 1) to determine the
+
+        Parameters:
+            sample_ratio (float): A ratio (0 to 1) to determine the
                 ratio of documents/edges used (in relation to the Collection size)
                 to render each Collection Schema.
+            graph_name (str): The name of the graph to use to generate the schema. If
+                None, the entire database will be used.
+            include_examples (bool): A flag whether to scan the database for
+                example values and use them in the graph schema. Default is True.
         """
         if not 0 <= sample_ratio <= 1:
             raise ValueError("**sample_ratio** value must be in between 0 to 1")
 
-        # Stores the Edge Relationships between each ArangoDB Document Collection
-        graph_schema: List[Dict[str, Any]] = [
-            {"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
-            for g in self.db.graphs()
-        ]
+        if graph_name:
+            # Fetch a single graph
+            graph: Graph = self.db.graph(graph_name)
+            edge_definitions = graph.edge_definitions()
+
+            graph_schema: List[Dict[str, Any]] = [
+                {"graph_name": graph_name, "edge_definitions": edge_definitions}
+            ]
+
+            # Fetch graph-specific collections
+            collection_names = set(graph.vertex_collections())
+            for edge_definition in edge_definitions:
+                collection_names.add(edge_definition["edge_collection"])
+
+        else:
+            # Fetch all graphs
+            graph_schema = [
+                {"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
+                for g in self.db.graphs()
+            ]
+
+            # Fetch all collections
+            collection_names = {collection["name"] for collection in self.db.collections()}
 
         # Stores the schema of every ArangoDB Document/Edge collection
         collection_schema: List[Dict[str, Any]] = []
-
         for collection in self.db.collections():
             if collection["system"]:
                 continue
 
+            if collection["name"] not in collection_names:
+                continue
+
             # Extract collection name, type, and size
             col_name: str = collection["name"]
             col_type: str = collection["type"]
@@ -86,36 +154,170 @@ def generate_schema(
             limit_amount = ceil(sample_ratio * col_size) or 1
 
             aql = f"""
-                FOR doc in {col_name}
+                FOR doc in @@col_name
                     LIMIT {limit_amount}
                     RETURN doc
             """
 
             doc: Dict[str, Any]
             properties: List[Dict[str, str]] = []
-            for doc in self.__db.aql.execute(aql):
+            for doc in self.db.aql.execute(aql, bind_vars={"@col_name": col_name}):
                 for key, value in doc.items():
                     properties.append({"name": key, "type": type(value).__name__})
 
-            collection_schema.append(
-                {
-                    "collection_name": col_name,
-                    "collection_type": col_type,
-                    f"{col_type}_properties": properties,
-                    f"example_{col_type}": doc,
-                }
-            )
+            collection_schema_entry = {
+                "name": col_name,
+                "type": col_type,
+                "properties": properties,
+            }
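
For reference, a sketch of the structure `generate_schema` now returns (collection and field values are made up; the keys come from the code above and below):

    schema = {
        "graph_schema": [
            {
                "graph_name": "GameOfThrones",
                "edge_definitions": [
                    {
                        "edge_collection": "ChildOf",
                        "from_vertex_collections": ["Characters"],
                        "to_vertex_collections": ["Characters"],
                    }
                ],
            }
        ],
        "collection_schema": [
            {
                "name": "Characters",
                "type": "document",
                "properties": [
                    {"name": "_key", "type": "str"},
                    {"name": "surname", "type": "str"},
                ],
                # present only when include_examples=True; sanitized via value_sanitize
                "example": {"_key": "NedStark", "surname": "Stark"},
            }
        ],
    }
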
return {"Graph Schema": graph_schema, "Collection Schema": collection_schema} + if include_examples: + collection_schema_entry[f"example"] = value_sanitize(doc) - def query( - self, query: str, top_k: Optional[int] = None, **kwargs: Any - ) -> List[Dict[str, Any]]: - """Query the ArangoDB database.""" - import itertools + collection_schema.append(collection_schema_entry) + + return {"graph_schema": graph_schema, "collection_schema": collection_schema} + def query(self, query: str, top_k: Optional[int] = None, **kwargs: Any) -> List[Dict[str, Any]]: + """Query the ArangoDB database.""" cursor = self.__db.aql.execute(query, **kwargs) - return [doc for doc in itertools.islice(cursor, top_k)] + return [value_sanitize(doc) for doc in itertools.islice(cursor, top_k)] + + def explain(self, query: str, *args: Any, **kwargs: Any) -> List[Dict[str, Any]]: + """Explain an AQL query without executing it.""" + return self.__db.aql.explain(query) + + def add_graph_documents( + self, + graph_documents: List[GraphDocument], + include_source: bool = False, + batch_size: int = 1000, + graph_name: Optional[str] = None, + ) -> None: + """ + This method constructs nodes and relationships in the graph based on the + provided GraphDocument objects. + + Parameters: + - graph_documents (List[GraphDocument]): A list of GraphDocument objects + that contain the nodes and relationships to be added to the graph. Each + GraphDocument should encapsulate the structure of part of the graph, + including nodes, relationships, and the source document information. + - include_source (bool, optional): If True, stores the source document + and links it to nodes in the graph using the MENTIONS relationship. + This is useful for tracing back the origin of data. Merges source + documents based on the `id` property from the source document metadata + if available; otherwise it calculates the MD5 hash of `page_content` + for merging process. Defaults to False. + - graph_name (str): The name of the ArangoDB General Graph to create. If None, + no graph will be created. 
+ """ + if not graph_documents: + return + + nodes = defaultdict(list) + edges = defaultdict(list) + edge_definitions_dict = defaultdict(lambda: defaultdict(set)) + + if include_source: + if not self.db.has_collection("MENTIONS"): + self.db.create_collection("MENTIONS", edge=True) + + if not self.db.has_collection("GraphDocumentSource"): + self.db.create_collection("GraphDocumentSource") + + edge_definitions_dict["MENTIONS"] = { + "edge_collection": "MENTIONS", + "from_vertex_collections": {"GraphDocumentSource"}, + "to_vertex_collections": set(), + } + + for document in graph_documents: + for i, node in enumerate(document.nodes, 1): + node_data = {"_key": str(node.id), **node.properties} + nodes[node.type].append(node_data) + + if i % batch_size == 0: + self.__import_data(nodes, is_edge=False) + + self.__import_data(nodes, is_edge=False) + + # Insert relationships + for i, rel in enumerate(document.relationships, 1): + source: Node = rel.source + target: Node = rel.target + + edge_definitions_dict[rel.type]["edge_collection"].add(rel.type) + edge_definitions_dict[rel.type]["from_vertex_collections"].add(source.type) + edge_definitions_dict[rel.type]["to_vertex_collections"].add(target.type) + + rel_data = { + "_from": f"{source.type}/{source.id}", + "_to": f"{target.type}/{target.id}", + **rel.properties, + } + + edges[rel.type].append(rel_data) + + if i % batch_size == 0: + self.__import_data(edges, is_edge=True) + + self.__import_data(edges, is_edge=True) + + # Insert source document if required + if include_source: + doc_source: Document = document.source + + _key = str(doc_source.metadata.get("id", uuid4())) + source_data = { + "_key": _key, + "text": doc_source.page_content, + "metadata": doc_source.metadata, + } + + self.db.collection("GraphDocumentSource").insert(source_data, overwrite=True) + + mentions = [] + mentions_col = self.db.collection("MENTIONS") + for i, node in enumerate(document.nodes, 1): + edge_definitions_dict["MENTIONS"]["to_vertex_collections"].add(node.type) + + mentions.append( + { + "_from": f"GraphDocumentSource/{_key}", + "_to": f"{node.type}/{str(node.id)}", + } + ) + + if i % batch_size == 0: + mentions_col.import_bulk(mentions, on_duplicate="update") + mentions.clear() + + mentions_col.import_bulk(mentions, on_duplicate="update") + + if graph_name: + edge_definitions = [] + for k, v in edge_definitions_dict.items(): + edge_definitions.append( + { + "edge_collection": k, + "from_vertex_collections": list(v["from_vertex_collections"]), + "to_vertex_collections": list(v["to_vertex_collections"]), + } + ) + + if not self.db.has_graph(graph_name): + self.db.create_graph(graph_name, edge_definitions) + else: + graph = self.db.graph(graph_name) + for e_d in edge_definitions: + if not graph.has_edge_definition(e_d["edge_collection"]): + graph.create_edge_definition(*e_d.values()) + else: + graph.replace_edge_definition(*e_d.values()) + + # Refresh schema after insertions + self.refresh_schema() @classmethod def from_db_credentials( @@ -140,11 +342,18 @@ def from_db_credentials( Returns: An arango.database.StandardDatabase. 
""" - db = get_arangodb_client( - url=url, dbname=dbname, username=username, password=password - ) + db = get_arangodb_client(url=url, dbname=dbname, username=username, password=password) return cls(db) + def __import_data(self, data: Dict[str, List[Dict[str, Any]]], is_edge: bool) -> None: + for collection, batch in data.items(): + if not self.db.has_collection(collection): + self.db.create_collection(collection, edge=is_edge) + + self.db.collection(collection).import_bulk(batch, on_duplicate="update") + + data.clear() + def get_arangodb_client( url: Optional[str] = None, @@ -170,9 +379,8 @@ def get_arangodb_client( try: from arango import ArangoClient except ImportError as e: - raise ImportError( - "Unable to import arango, please install with `pip install python-arango`." - ) from e + m = "Unable to import arango, please install with `pip install python-arango`." + raise ImportError(m) from e _url: str = url or os.environ.get("ARANGODB_URL", "http://localhost:8529") # type: ignore[assignment] _dbname: str = dbname or os.environ.get("ARANGODB_DBNAME", "_system") # type: ignore[assignment] diff --git a/libs/community/langchain_community/vectorstores/arangodb_vector.py b/libs/community/langchain_community/vectorstores/arangodb_vector.py new file mode 100644 index 0000000000000..2f0b844357a3e --- /dev/null +++ b/libs/community/langchain_community/vectorstores/arangodb_vector.py @@ -0,0 +1,462 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type +from uuid import uuid4 + +import numpy as np +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore +from packaging import version + +try: + from arango.database import Database + from arango.exceptions import ArangoServerError + from arango.graph import Graph + + ARANGO_INSTALLED = True +except ImportError: + print("ArangoDB not installed, please install with `pip install python-arango`.") + ARANGO_INSTALLED = False + +from langchain_community.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance + +DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE +DISTANCE_MAPPING = { + DistanceStrategy.EUCLIDEAN_DISTANCE: "l2", + DistanceStrategy.COSINE: "cosine", +} + + +class SearchType(str, Enum): + """Enumerator of the Distance strategies.""" + + VECTOR = "vector" + # HYBRID = "hybrid" # TODO + + +DEFAULT_SEARCH_TYPE = SearchType.VECTOR + + +class ArangoVector(VectorStore): + """ArangoDB vector index. + + To use this, you should have the `python-arango` python package installed. + + Args: + embedding: Any embedding function implementing + `langchain.embeddings.base.Embeddings` interface. + database: The python-arango database instance. + embedding_dimension: The dimension of the to-be-inserted embedding vectors. + search_type: The type of search to be performed, currently only 'vector' is supported. + collection_name: The name of the collection to use. (default: "documents") + index_name: The name of the vector index to use. (default: "vector_index") + text_field: The field name storing the text. (default: "text") + embedding_field: The field name storing the embedding vector. (default: "embedding") + distance_strategy: The distance strategy to use. (default: "COSINE") + num_centroids: The number of centroids for the vector index. (default: 1) + + Example: + .. 
diff --git a/libs/community/langchain_community/vectorstores/arangodb_vector.py b/libs/community/langchain_community/vectorstores/arangodb_vector.py
new file mode 100644
index 0000000000000..2f0b844357a3e
--- /dev/null
+++ b/libs/community/langchain_community/vectorstores/arangodb_vector.py
@@ -0,0 +1,462 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type
+from uuid import uuid4
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+from packaging import version
+
+try:
+    from arango.database import Database
+    from arango.exceptions import ArangoServerError
+    from arango.graph import Graph
+
+    ARANGO_INSTALLED = True
+except ImportError:
+    print("ArangoDB not installed, please install with `pip install python-arango`.")
+    ARANGO_INSTALLED = False
+
+from langchain_community.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
+
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
+DISTANCE_MAPPING = {
+    DistanceStrategy.EUCLIDEAN_DISTANCE: "l2",
+    DistanceStrategy.COSINE: "cosine",
+}
+
+
+class SearchType(str, Enum):
+    """Enumerator of the search types."""
+
+    VECTOR = "vector"
+    # HYBRID = "hybrid"  # TODO
+
+
+DEFAULT_SEARCH_TYPE = SearchType.VECTOR
+
+
+class ArangoVector(VectorStore):
+    """ArangoDB vector index.
+
+    To use this, you should have the `python-arango` python package installed.
+
+    Args:
+        embedding: Any embedding function implementing
+            `langchain.embeddings.base.Embeddings` interface.
+        database: The python-arango database instance.
+        embedding_dimension: The dimension of the to-be-inserted embedding vectors.
+        search_type: The type of search to be performed, currently only
+            'vector' is supported.
+        collection_name: The name of the collection to use. (default: "documents")
+        index_name: The name of the vector index to use. (default: "vector_index")
+        text_field: The field name storing the text. (default: "text")
+        embedding_field: The field name storing the embedding vector.
+            (default: "embedding")
+        distance_strategy: The distance strategy to use. (default: "COSINE")
+        num_centroids: The number of centroids for the vector index. (default: 1)
+
+    Example:
+        .. code-block:: python
+
+            from arango import ArangoClient
+            from langchain_community.embeddings.openai import OpenAIEmbeddings
+            from langchain_community.vectorstores.arangodb_vector import ArangoVector
+
+            db = ArangoClient("http://localhost:8529").db(
+                "test", username="root", password="openSesame"
+            )
+
+            vector_store = ArangoVector(
+                embedding=OpenAIEmbeddings(),
+                database=db,
+                embedding_dimension=1536,  # must match the embedding model
+            )
+
+            texts = ["hello world", "hello langchain", "hello arangodb"]
+
+            vector_store.add_texts(texts)
+
+            print(vector_store.similarity_search("arangodb", k=1))
+    """
+
+    def __init__(
+        self,
+        embedding: Embeddings,
+        *,
+        database: "Database",
+        embedding_dimension: int,
+        search_type: SearchType = DEFAULT_SEARCH_TYPE,
+        collection_name: str = "documents",
+        index_name: str = "vector_index",
+        text_field: str = "text",
+        embedding_field: str = "embedding",
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        num_centroids: int = 1,
+    ):
+        if not ARANGO_INSTALLED:
+            m = "ArangoDB not installed, please install with `pip install python-arango`."
+            raise ImportError(m)
+
+        # TODO: Enable when ready
+        # if version.parse(database.version()) < version.parse("3.12.0"):
+        #     raise ValueError("ArangoDB version must be 3.12.0 or greater")
+
+        if search_type not in [SearchType.VECTOR]:
+            raise ValueError("search_type must be 'vector'")
+
+        if distance_strategy not in [
+            DistanceStrategy.COSINE,
+            DistanceStrategy.EUCLIDEAN_DISTANCE,
+        ]:
+            raise ValueError("distance_strategy must be 'COSINE' or 'EUCLIDEAN_DISTANCE'")
+
+        self.db = database
+        self.embedding = embedding
+        self.collection_name = collection_name
+        self.index_name = index_name
+        self.embedding_field = embedding_field
+        self.text_field = text_field
+        self.distance_strategy = DISTANCE_MAPPING[distance_strategy]
+        self.embedding_dimension = embedding_dimension
+        self.num_centroids = num_centroids
+
+        if not database.has_collection(collection_name):
+            database.create_collection(collection_name)
+
+        self.collection = database.collection(self.collection_name)
+
+    @property
+    def embeddings(self) -> Embeddings:
+        return self.embedding
+
+    def add_embeddings(
+        self,
+        texts: Iterable[str],
+        embeddings: List[List[float]],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        if ids is None:
+            ids = [str(uuid4()) for _ in texts]
+
+        if not metadatas:
+            metadatas = [{} for _ in texts]
+
+        to_insert = [
+            {
+                "_key": id_,
+                self.text_field: text,
+                self.embedding_field: embedding,
+                "metadata": metadata,
+            }
+            for id_, text, embedding, metadata in zip(ids, texts, embeddings, metadatas)
+        ]
+
+        self.collection.import_bulk(to_insert, on_duplicate="update", **kwargs)
+
+        if self.index_name not in [index["name"] for index in self.collection.indexes()]:
+            self.collection.add_index(
+                {
+                    "name": self.index_name,
+                    "type": "vector",
+                    "fields": [self.embedding_field],
+                    "params": {
+                        "metric": self.distance_strategy,
+                        "dimensions": self.embedding_dimension,
+                        "nLists": self.num_centroids,
+                    },
+                }
+            )
+
+        return ids
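
`add_embeddings` lazily creates the vector index on first insert. A sketch to confirm what lands in ArangoDB, reusing the `vector_store` from the class docstring example (the exact fields reported by `collection.indexes()` may vary by server version):

    ids = vector_store.add_texts(["hello world", "hello arangodb"])

    index = next(
        idx for idx in vector_store.collection.indexes()
        if idx["name"] == "vector_index"
    )
    # Expect: type "vector", metric "cosine", dimensions matching the
    # embedding model, and nLists equal to num_centroids.
    print(index)
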
+ """ + embeddings = self.embedding.embed_documents(list(texts)) + return self.add_embeddings(texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs) + + def similarity_search( + self, + query: str, + k: int = 4, + return_full_doc: bool = True, + **kwargs: Any, + ) -> List[Document]: + """Run similarity search with ArangoDB. + + Args: + query (str): Query text to search for. + k (int): Number of results to return. Defaults to 4. + return_full_doc (bool): Whether to return the full document. + If false, will just return the _key. Defaults to True. + + Returns: + List of Documents most similar to the query. + """ + embedding = self.embedding.embed_query(query) + return self.similarity_search_by_vector(embedding, k, return_full_doc) + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + return_full_doc: bool = True, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + return_full_doc (bool): Whether to return the full document. + If false, will just return the _key. Defaults to True. + + Returns: + List of Documents most similar to the query vector. + """ + docs_and_scores = self.similarity_search_by_vector_with_score( + embedding=embedding, k=k, return_full_doc=return_full_doc, **kwargs + ) + + return [doc for doc, _ in docs_and_scores] + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + return_full_doc: bool = True, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + return_full_doc (bool): Whether to return the full document. + If false, will just return the _key. Defaults to True. + + Returns: + List of Documents most similar to the query and score for each + """ + embedding = self.embedding.embed_query(query) + result = self.similarity_search_by_vector_with_score( + embedding=embedding, + k=k, + query=query, + return_full_doc=return_full_doc, + **kwargs, + ) + return result + + def similarity_search_by_vector_with_score( + self, + embedding: List[float], + k: int = 4, + return_full_doc: bool = False, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + return_full_doc (bool): Whether to return the full document. + If false, will just return the _key. Defaults to True. + + Returns: + List of Documents most similar to the query vector. + """ + if self.distance_strategy == "cosine": + sort_func = "APPROX_NEAR_COSINE" + elif self.distance_strategy == "l2": + sort_func = "APPROX_NEAR_L2" + else: + raise ValueError(f"Unsupported metric: {self.distance_strategy}") + + aql = f""" + FOR doc IN @@collection + LET score = {sort_func}(doc.{self.embedding_field}, @embedding) + SORT score DESC + LIMIT @k + LET data = @return_full_doc ? 
+    def similarity_search_by_vector_with_score(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        return_full_doc: bool = False,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            return_full_doc (bool): Whether to return the full document.
+                If False, only the `_key` and text field are returned.
+                Defaults to False.
+
+        Returns:
+            List of Documents most similar to the query vector.
+        """
+        if self.distance_strategy == "cosine":
+            sort_func = "APPROX_NEAR_COSINE"
+            sort_order = "DESC"  # higher cosine similarity is closer
+        elif self.distance_strategy == "l2":
+            sort_func = "APPROX_NEAR_L2"
+            sort_order = "ASC"  # lower L2 distance is closer
+        else:
+            raise ValueError(f"Unsupported metric: {self.distance_strategy}")
+
+        aql = f"""
+            FOR doc IN @@collection
+                LET score = {sort_func}(doc.{self.embedding_field}, @embedding)
+                SORT score {sort_order}
+                LIMIT @k
+                LET data = @return_full_doc
+                    ? doc
+                    : {{'_key': doc._key, {self.text_field}: doc.{self.text_field}}}
+                RETURN {{data, score}}
+        """
+
+        bind_vars = {
+            "@collection": self.collection_name,
+            "embedding": embedding,
+            "k": k,
+            "return_full_doc": return_full_doc,
+        }
+
+        cursor = self.db.aql.execute(aql, bind_vars=bind_vars)
+
+        results = []
+        for result in cursor:
+            data = result["data"]
+            page_content = data.pop(self.text_field)
+            results.append((Document(page_content=page_content, metadata=data), result["score"]))
+
+        return results
+
+    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+        """Delete by vector ID or other criteria.
+
+        Args:
+            ids: List of ids to delete.
+            **kwargs: Other keyword arguments that can be used to delete vectors.
+
+        Returns:
+            Optional[bool]: True if deletion is successful,
+            False otherwise, None if not implemented.
+        """
+        for result in self.collection.delete_many(ids, **kwargs):
+            if isinstance(result, ArangoServerError):
+                print(result)
+                return False
+
+        return True
+
+    def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
+        """Get documents by their IDs.
+
+        Args:
+            ids: List of ids to get.
+
+        Returns:
+            List of Documents with the given ids.
+        """
+        docs = []
+        for result in self.collection.get_many(ids):
+            page_content = result.pop(self.text_field)
+            docs.append(Document(page_content=page_content, metadata=result))
+
+        return docs
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: search query text.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        # Embed the query
+        query_embedding = self.embedding.embed_query(query)
+
+        # Fetch the initial documents
+        docs_with_scores = self.similarity_search_by_vector_with_score(
+            embedding=query_embedding,
+            k=fetch_k,
+            return_full_doc=True,
+            **kwargs,
+        )
+
+        # Get the embeddings for the fetched documents
+        embeddings = [doc.metadata[self.embedding_field] for doc, _ in docs_with_scores]
+
+        # Select documents using maximal marginal relevance
+        selected_indices = maximal_marginal_relevance(
+            np.array(query_embedding), embeddings, lambda_mult=lambda_mult, k=k
+        )
+
+        selected_docs = [docs_with_scores[i][0] for i in selected_indices]
+
+        # Remove embedding values from metadata
+        for doc in selected_docs:
+            doc.metadata.pop(self.embedding_field, None)
+
+        return selected_docs
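
MMR usage sketch; `fetch_k` controls the candidate pool that gets re-ranked for diversity:

    diverse_docs = vector_store.max_marginal_relevance_search(
        "arangodb",
        k=2,              # results returned
        fetch_k=10,       # candidates fetched before MMR re-ranking
        lambda_mult=0.5,  # 0 = maximum diversity, 1 = minimum diversity
    )
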
+ """ + if self.override_relevance_score_fn is not None: + return self.override_relevance_score_fn + + # Default strategy is to rely on distance strategy provided + # in vectorstore constructor + if self._distance_strategy == DistanceStrategy.COSINE: + return lambda x: x + elif self._distance_strategy == DistanceStrategy.L2: + return lambda x: x + else: + raise ValueError( + "No supported normalization function" + f" for distance_strategy of {self._distance_strategy}." + "Consider providing relevance_score_fn to PGVector constructor." + ) + + @classmethod + def from_texts( + cls: Type[ArangoVector], + texts: List[str], + embedding: Embeddings, + database: "Database", + search_type: SearchType = DEFAULT_SEARCH_TYPE, + collection_name: str = "documents", + index_name: str = "vector_index", + text_field: str = "text", + embedding_field: str = "embedding", + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + num_centroids: int = 1, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> ArangoVector: + """ + Return ArangoDBVector initialized from texts, embeddings and a database. + """ + embeddings = embedding.embed_documents(list(texts)) + + embedding_dimension = len(embeddings[0]) + + store = cls( + embedding, + database=database, + collection_name=collection_name, + embedding_dimension=embedding_dimension, + search_type=search_type, + index_name=index_name, + text_field=text_field, + embedding_field=embedding_field, + distance_strategy=distance_strategy, + num_centroids=num_centroids, + **kwargs, + ) + + store.add_embeddings(texts, embeddings, metadatas=metadatas, ids=ids, **kwargs) + + return store