Skip to content

Commit

Permalink
Update API to be uniform
Browse files Browse the repository at this point in the history
  • Loading branch information
leonbi100 committed Dec 18, 2024
1 parent fe09b25 commit 2105f55
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 50 deletions.
101 changes: 53 additions & 48 deletions integrations/langchain/src/databricks_langchain/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,69 +6,74 @@
from langchain_core.embeddings import Embeddings
from langchain.tools.retriever import create_retriever_tool

from typing import Any, Dict, Optional, Type

from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)

# Input schema for the retriever tool: a single free-text query string.
# Kept as `#` comments rather than a class docstring so that the generated
# Pydantic schema (which would pick up a docstring as its description) is unchanged.
class VectorSearchRetrieverToolInput(BaseModel):
    query: str = Field(description="query used to search the index")

class VectorSearchRetrieverTool():
# NOTE(review): this span appears to come from a rendered commit diff whose
# +/- markers and indentation were stripped, so two revisions seem to be
# interleaved (both a `__new__` and an `__init__` signature are present, and
# `index_name` is documented twice). Confirm against the repository before
# changing any code here.
"""
A utility class to create a vector search-based retrieval tool for querying indexed embeddings.
This class integrates with Databricks Vector Search and provides a convenient interface
for building a retriever tool for agents.
Parameters:
index_name (str):
The name of the index to use. Format: "catalog.schema.index".
num_results (int):
The number of results to return. Defaults to 10.
columns (Optional[List[str]]):
The list of column names to get when doing the search. Defaults to [primary_key, text_column].
filters (Optional[Dict[str, Any]]):
The filters to apply to the search. Defaults to None.
query_type (str):
The type of query to run. Defaults to "ANN".
tool_name (str):
The name of the retrieval tool to be created. This will be passed to the language model,
so should be unique and somewhat descriptive.
tool_description (str):
A description of the tool's functionality. This will be passed to the language model,
so should be descriptive.
index_name (str):
The name of the index to use. Format: "catalog.schema.index".
endpoint (Optional[str]):
The name of the Databricks Vector Search endpoint. If not specified, the endpoint name is
automatically inferred based on the index name.
embedding (Optional[Embeddings]):
The embedding model. Required for direct-access index or delta-sync index with self-managed embeddings.
text_column (Optional[str]):
The name of the text column to use for the embeddings. Required for direct-access index or
delta-sync index with self-managed embeddings. Make sure the text column specified is in the index.
columns (Optional[List[str]]):
The list of column names to get when doing the search. Defaults to [primary_key, text_column].
search_type (Optional[str]): Defines the type of search that the Retriever should perform.
Defaults to "similarity".
search_kwargs (Optional[Dict]): Keyword arguments to pass to the search function.
"""
def __new__(
cls,
tool_name: str,
tool_description: str,
name: str = Field(description="The name of the tool")
description: str = Field(description="The description of the tool")
args_schema: Type[BaseModel] = VectorSearchRetrieverToolInput
def __init__(
self,
index_name: str,
endpoint: Optional[str] = None,
embedding: Optional[Embeddings] = None,
text_column: Optional[str] = None,
num_results: int = 10,
*,
columns: Optional[List[str]] = None,
search_type: Optional[str] = None,
search_kwargs: Optional[dict] = None,
filters: Optional[Dict[str, Any]] = None,
query_type: str = "ANN",
tool_name: Optional[str] = None,
tool_description: Optional[str], # TODO: default this to the description in the UC metadata; how to fetch it is still open (call via the client?)
):
# Only options that were explicitly supplied (not None) are forwarded to the vector store.
vector_store_kwargs = {"index_name": index_name}
if endpoint is not None:
vector_store_kwargs["endpoint"] = endpoint
if embedding is not None:
vector_store_kwargs["embedding"] = embedding
if text_column is not None:
vector_store_kwargs["text_column"] = text_column
if columns is not None:
vector_store_kwargs["columns"] = columns
vector_store = DatabricksVectorSearch(**vector_store_kwargs)

# Likewise, retriever options are passed through only when provided.
retriever_kwargs = {}
if search_type is not None:
retriever_kwargs["search_type"] = search_type
if search_kwargs is not None:
retriever_kwargs["search_kwargs"] = search_kwargs
# Use the index name as the tool name if no tool name is provided
self.name = index_name
if tool_name:
self.name = tool_name
self.num_results = num_results
self.columns = columns
self.filters = filters
self.query_type = query_type
self.description = tool_description
self.vector_store = DatabricksVectorSearch(index_name=index_name)

retriever = vector_store.as_retriever(**retriever_kwargs)
return create_retriever_tool(
retriever,
tool_name,
tool_description,
)
def _run(
self,
query: str
) -> str:
"""Run a similarity search for `query` against `self.vector_store`, using the stored
`num_results`, `columns`, `filters` and `query_type` settings."""
# NOTE(review): the search result is not returned, so this method yields None
# despite the `-> str` annotation. The arguments are also passed positionally;
# confirm they line up with `similarity_search`'s parameter order.
self.vector_store.similarity_search(query, self.num_results, self.columns, self.filters, self.query_type)
4 changes: 2 additions & 2 deletions integrations/langchain/tests/test_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def init_vector_search_tool(
index_name: str, columns: Optional[List[str]] = None
) -> VectorSearchRetrieverTool:
kwargs: Dict[str, Any] = {
"tool_name": "test_tool",
"tool_description": "Test tool for vector search",
"index_name": index_name,
"columns": columns,
"tool_name": "test_tool",
"tool_description": "Test tool for vector search",
}
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
Expand Down

0 comments on commit 2105f55

Please sign in to comment.