From 2105f55bb35e2e0121b26b34def8e7d3a928bc32 Mon Sep 17 00:00:00 2001 From: leonbi100 Date: Wed, 18 Dec 2024 14:19:29 -0500 Subject: [PATCH] Update API to be uniform --- .../src/databricks_langchain/vector_search.py | 101 +++++++++--------- .../langchain/tests/test_vector_search.py | 4 +- 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/vector_search.py b/integrations/langchain/src/databricks_langchain/vector_search.py index 94fee73..8871086 100644 --- a/integrations/langchain/src/databricks_langchain/vector_search.py +++ b/integrations/langchain/src/databricks_langchain/vector_search.py @@ -6,6 +6,22 @@ from langchain_core.embeddings import Embeddings from langchain.tools.retriever import create_retriever_tool +from typing import Any, Dict, Optional, Type + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain_core.tools import BaseTool +from pydantic import BaseModel, Field +from langchain_core.callbacks import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) + +class VectorSearchRetrieverToolInput(BaseModel): + query: str = Field(description="query used to search the index") + class VectorSearchRetrieverTool(): """ A utility class to create a vector search-based retrieval tool for querying indexed embeddings. @@ -13,62 +29,51 @@ class VectorSearchRetrieverTool(): for building a retriever tool for agents. Parameters: + index_name (str): + The name of the index to use. Format: “catalog.schema.index”. endpoint: + num_results (int): + The number of results to return. Defaults to 10. + columns (Optional[List[str]]): + The list of column names to get when doing the search. Defaults to [primary_key, text_column]. + filters (Optional[Dict[str, Any]]): + The filters to apply to the search. Defaults to None. + query_type (str): + The type of query to run. Defaults to "ANN". tool_name (str): The name of the retrieval tool to be created. This will be passed to the language model, so should be unique and somewhat descriptive. tool_description (str): A description of the tool's functionality. This will be passed to the language model, so should be descriptive. - - index_name (str): - The name of the index to use. Format: “catalog.schema.index”. endpoint: - endpoint (Optional[str]): - The name of the Databricks Vector Search endpoint. If not specified, the endpoint name is - automatically inferred based on the index name. - embedding (Optional[Embeddings]): - The embedding model. Required for direct-access index or delta-sync index with self-managed embeddings. - text_column (Optional[str]): - The name of the text column to use for the embeddings. Required for direct-access index or - delta-sync index with self-managed embeddings. Make sure the text column specified is in the index. - columns (Optional[List[str]]): - The list of column names to get when doing the search. Defaults to [primary_key, text_column]. - - search_type (Optional[str]): Defines the type of search that the Retriever should perform. - Defaults to “similarity” (default). - search_kwargs (Optional[Dict]): Keyword arguments to pass to the search function. """ - def __new__( - cls, - tool_name: str, - tool_description: str, + name: str = Field(description="The name of the tool") + description: str = Field(description="The description of the tool") + args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput + def __init__( + self, index_name: str, - endpoint: Optional[str] = None, - embedding: Optional[Embeddings] = None, - text_column: Optional[str] = None, + num_results: int = 10, + *, columns: Optional[List[str]] = None, - search_type: Optional[str] = None, - search_kwargs: Optional[dict] = None, + filters: Optional[Dict[str, Any]] = None, + query_type: str = "ANN", + tool_name: Optional[str] = None, + tool_description: Optional[str], # TODO: By default the UC metadata for description, how do I get this info? Call using client? ): - vector_store_kwargs = {"index_name": index_name} - if endpoint is not None: - vector_store_kwargs["endpoint"] = endpoint - if embedding is not None: - vector_store_kwargs["embedding"] = embedding - if text_column is not None: - vector_store_kwargs["text_column"] = text_column - if columns is not None: - vector_store_kwargs["columns"] = columns - vector_store = DatabricksVectorSearch(**vector_store_kwargs) - - retriever_kwargs = {} - if search_type is not None: - retriever_kwargs["search_type"] = search_type - if search_kwargs is not None: - retriever_kwargs["search_kwargs"] = search_kwargs + # Use the index name as the tool name if no tool name is provided + self.name = index_name + if tool_name: + self.name = tool_name + self.num_results = num_results + self.columns = columns + self.filters = filters + self.query_type = query_type + self.description = tool_description + self.vector_store = DatabricksVectorSearch(index_name=index_name) - retriever = vector_store.as_retriever(**retriever_kwargs) - return create_retriever_tool( - retriever, - tool_name, - tool_description, - ) + def _run( + self, + query: str + ) -> str: + """Use the tool.""" + self.vector_store.similarity_search(query, self.num_results, self.columns, self.filters, self.query_type) diff --git a/integrations/langchain/tests/test_vector_search.py b/integrations/langchain/tests/test_vector_search.py index 5a707b1..b3a4892 100644 --- a/integrations/langchain/tests/test_vector_search.py +++ b/integrations/langchain/tests/test_vector_search.py @@ -12,10 +12,10 @@ def init_vector_search_tool( index_name: str, columns: Optional[List[str]] = None ) -> VectorSearchRetrieverTool: kwargs: Dict[str, Any] = { - "tool_name": "test_tool", - "tool_description": "Test tool for vector search", "index_name": index_name, "columns": columns, + "tool_name": "test_tool", + "tool_description": "Test tool for vector search", } if index_name != DELTA_SYNC_INDEX: kwargs.update(